MLIR  22.0.0git
Deserializer.cpp
Go to the documentation of this file.
1 //===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Deserializer.h"
14 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/IRMapping.h"
21 #include "mlir/IR/Location.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
31 #include <optional>
32 
33 using namespace mlir;
34 
35 #define DEBUG_TYPE "spirv-deserialization"
36 
37 //===----------------------------------------------------------------------===//
38 // Utility Functions
39 //===----------------------------------------------------------------------===//
40 
41 /// Returns true if the given `block` is a function entry block.
42 static inline bool isFnEntryBlock(Block *block) {
43  return block->isEntryBlock() &&
44  isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Deserializer Method Definitions
49 //===----------------------------------------------------------------------===//
50 
52  MLIRContext *context,
54  : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55  module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56 #ifndef NDEBUG
57  ,
58  logger(llvm::dbgs())
59 #endif
60 {
61 }
62 
64  LLVM_DEBUG({
65  logger.resetIndent();
66  logger.startLine()
67  << "//+++---------- start deserialization ----------+++//\n";
68  });
69 
70  if (failed(processHeader()))
71  return failure();
72 
73  spirv::Opcode opcode = spirv::Opcode::OpNop;
74  ArrayRef<uint32_t> operands;
75  auto binarySize = binary.size();
76  while (curOffset < binarySize) {
77  // Slice the next instruction out and populate `opcode` and `operands`.
78  // Internally this also updates `curOffset`.
79  if (failed(sliceInstruction(opcode, operands)))
80  return failure();
81 
82  if (failed(processInstruction(opcode, operands)))
83  return failure();
84  }
85 
86  assert(curOffset == binarySize &&
87  "deserializer should never index beyond the binary end");
88 
89  for (auto &deferred : deferredInstructions) {
90  if (failed(processInstruction(deferred.first, deferred.second, false))) {
91  return failure();
92  }
93  }
94 
95  attachVCETriple();
96 
97  LLVM_DEBUG(logger.startLine()
98  << "//+++-------- completed deserialization --------+++//\n");
99  return success();
100 }
101 
103  return std::move(module);
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Module structure
108 //===----------------------------------------------------------------------===//
109 
110 OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111  OpBuilder builder(context);
112  OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113  spirv::ModuleOp::build(builder, state);
114  return cast<spirv::ModuleOp>(Operation::create(state));
115 }
116 
117 LogicalResult spirv::Deserializer::processHeader() {
118  if (binary.size() < spirv::kHeaderWordCount)
119  return emitError(unknownLoc,
120  "SPIR-V binary module must have a 5-word header");
121 
122  if (binary[0] != spirv::kMagicNumber)
123  return emitError(unknownLoc, "incorrect magic number");
124 
125  // Version number bytes: 0 | major number | minor number | 0
126  uint32_t majorVersion = (binary[1] << 8) >> 24;
127  uint32_t minorVersion = (binary[1] << 16) >> 24;
128  if (majorVersion == 1) {
129  switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
131  case v: \
132  version = spirv::Version::V_1_##v; \
133  break
134 
135  MIN_VERSION_CASE(0);
136  MIN_VERSION_CASE(1);
137  MIN_VERSION_CASE(2);
138  MIN_VERSION_CASE(3);
139  MIN_VERSION_CASE(4);
140  MIN_VERSION_CASE(5);
141  MIN_VERSION_CASE(6);
142 #undef MIN_VERSION_CASE
143  default:
144  return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
145  << minorVersion;
146  }
147  } else {
148  return emitError(unknownLoc, "unsupported SPIR-V major version: ")
149  << majorVersion;
150  }
151 
152  // TODO: generator number, bound, schema
153  curOffset = spirv::kHeaderWordCount;
154  return success();
155 }
156 
157 LogicalResult
158 spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159  if (operands.size() != 1)
160  return emitError(unknownLoc, "OpCapability must have one parameter");
161 
162  auto cap = spirv::symbolizeCapability(operands[0]);
163  if (!cap)
164  return emitError(unknownLoc, "unknown capability: ") << operands[0];
165 
166  capabilities.insert(*cap);
167  return success();
168 }
169 
170 LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
171  if (words.empty()) {
172  return emitError(
173  unknownLoc,
174  "OpExtension must have a literal string for the extension name");
175  }
176 
177  unsigned wordIndex = 0;
178  StringRef extName = decodeStringLiteral(words, wordIndex);
179  if (wordIndex != words.size())
180  return emitError(unknownLoc,
181  "unexpected trailing words in OpExtension instruction");
182  auto ext = spirv::symbolizeExtension(extName);
183  if (!ext)
184  return emitError(unknownLoc, "unknown extension: ") << extName;
185 
186  extensions.insert(*ext);
187  return success();
188 }
189 
190 LogicalResult
191 spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192  if (words.size() < 2) {
193  return emitError(unknownLoc,
194  "OpExtInstImport must have a result <id> and a literal "
195  "string for the extended instruction set name");
196  }
197 
198  unsigned wordIndex = 1;
199  extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
200  if (wordIndex != words.size()) {
201  return emitError(unknownLoc,
202  "unexpected trailing words in OpExtInstImport");
203  }
204  return success();
205 }
206 
207 void spirv::Deserializer::attachVCETriple() {
208  (*module)->setAttr(
209  spirv::ModuleOp::getVCETripleAttrName(),
210  spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
211  extensions.getArrayRef(), context));
212 }
213 
214 LogicalResult
215 spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216  if (operands.size() != 2)
217  return emitError(unknownLoc, "OpMemoryModel must have two operands");
218 
219  (*module)->setAttr(
220  module->getAddressingModelAttrName(),
221  opBuilder.getAttr<spirv::AddressingModelAttr>(
222  static_cast<spirv::AddressingModel>(operands.front())));
223 
224  (*module)->setAttr(module->getMemoryModelAttrName(),
225  opBuilder.getAttr<spirv::MemoryModelAttr>(
226  static_cast<spirv::MemoryModel>(operands.back())));
227 
228  return success();
229 }
230 
231 template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
233  Location loc, OpBuilder &opBuilder,
235  StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236  if (words.size() != 4) {
237  return emitError(loc, "OpDecoration with ")
238  << decorationName << "needs a cache control integer literal and a "
239  << cacheControlKind << " cache control literal";
240  }
241  unsigned cacheLevel = words[2];
242  auto cacheControlAttr = static_cast<EnumTy>(words[3]);
243  auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
245  if (auto attrList =
246  llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
247  llvm::append_range(attrs, attrList);
248  attrs.push_back(value);
249  decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
250  return success();
251 }
252 
253 LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
254  // TODO: This function should also be auto-generated. For now, since only a
255  // few decorations are processed/handled in a meaningful manner, going with a
256  // manual implementation.
257  if (words.size() < 2) {
258  return emitError(
259  unknownLoc, "OpDecorate must have at least result <id> and Decoration");
260  }
261  auto decorationName =
262  stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
263  if (decorationName.empty()) {
264  return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
265  }
266  auto symbol = getSymbolDecoration(decorationName);
267  switch (static_cast<spirv::Decoration>(words[1])) {
268  case spirv::Decoration::FPFastMathMode:
269  if (words.size() != 3) {
270  return emitError(unknownLoc, "OpDecorate with ")
271  << decorationName << " needs a single integer literal";
272  }
273  decorations[words[0]].set(
274  symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275  static_cast<FPFastMathMode>(words[2])));
276  break;
277  case spirv::Decoration::FPRoundingMode:
278  if (words.size() != 3) {
279  return emitError(unknownLoc, "OpDecorate with ")
280  << decorationName << " needs a single integer literal";
281  }
282  decorations[words[0]].set(
283  symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284  static_cast<FPRoundingMode>(words[2])));
285  break;
286  case spirv::Decoration::DescriptorSet:
287  case spirv::Decoration::Binding:
288  if (words.size() != 3) {
289  return emitError(unknownLoc, "OpDecorate with ")
290  << decorationName << " needs a single integer literal";
291  }
292  decorations[words[0]].set(
293  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
294  break;
295  case spirv::Decoration::BuiltIn:
296  if (words.size() != 3) {
297  return emitError(unknownLoc, "OpDecorate with ")
298  << decorationName << " needs a single integer literal";
299  }
300  decorations[words[0]].set(
301  symbol, opBuilder.getStringAttr(
302  stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
303  break;
304  case spirv::Decoration::ArrayStride:
305  if (words.size() != 3) {
306  return emitError(unknownLoc, "OpDecorate with ")
307  << decorationName << " needs a single integer literal";
308  }
309  typeDecorations[words[0]] = words[2];
310  break;
311  case spirv::Decoration::LinkageAttributes: {
312  if (words.size() < 4) {
313  return emitError(unknownLoc, "OpDecorate with ")
314  << decorationName
315  << " needs at least 1 string and 1 integer literal";
316  }
317  // LinkageAttributes has two parameters ["linkageName", linkageType]
318  // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
319  // "linkageName" is a stringliteral encoded as uint32_t,
320  // hence the size of name is variable length which results in words.size()
321  // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
322  // 3 + ceildiv(strlen(name), 4).
323  unsigned wordIndex = 2;
324  auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
325  auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
326  static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
327  auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328  StringAttr::get(context, linkageName), linkageTypeAttr);
329  decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
330  break;
331  }
332  case spirv::Decoration::Aliased:
333  case spirv::Decoration::AliasedPointer:
334  case spirv::Decoration::Block:
335  case spirv::Decoration::BufferBlock:
336  case spirv::Decoration::Flat:
337  case spirv::Decoration::NonReadable:
338  case spirv::Decoration::NonWritable:
339  case spirv::Decoration::NoPerspective:
340  case spirv::Decoration::NoSignedWrap:
341  case spirv::Decoration::NoUnsignedWrap:
342  case spirv::Decoration::RelaxedPrecision:
343  case spirv::Decoration::Restrict:
344  case spirv::Decoration::RestrictPointer:
345  case spirv::Decoration::NoContraction:
346  case spirv::Decoration::Constant:
347  case spirv::Decoration::Invariant:
348  case spirv::Decoration::Patch:
349  if (words.size() != 2) {
350  return emitError(unknownLoc, "OpDecoration with ")
351  << decorationName << "needs a single target <id>";
352  }
353  decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
354  break;
355  case spirv::Decoration::Location:
356  case spirv::Decoration::SpecId:
357  if (words.size() != 3) {
358  return emitError(unknownLoc, "OpDecoration with ")
359  << decorationName << "needs a single integer literal";
360  }
361  decorations[words[0]].set(
362  symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
363  break;
364  case spirv::Decoration::CacheControlLoadINTEL: {
365  LogicalResult res = deserializeCacheControlDecoration<
366  CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
368  "load");
369  if (failed(res))
370  return res;
371  break;
372  }
373  case spirv::Decoration::CacheControlStoreINTEL: {
374  LogicalResult res = deserializeCacheControlDecoration<
375  CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376  unknownLoc, opBuilder, decorations, words, symbol, decorationName,
377  "store");
378  if (failed(res))
379  return res;
380  break;
381  }
382  default:
383  return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
384  }
385  return success();
386 }
387 
388 LogicalResult
389 spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
390  // The binary layout of OpMemberDecorate is different comparing to OpDecorate
391  if (words.size() < 3) {
392  return emitError(unknownLoc,
393  "OpMemberDecorate must have at least 3 operands");
394  }
395 
396  auto decoration = static_cast<spirv::Decoration>(words[2]);
397  if (decoration == spirv::Decoration::Offset && words.size() != 4) {
398  return emitError(unknownLoc,
399  " missing offset specification in OpMemberDecorate with "
400  "Offset decoration");
401  }
402  ArrayRef<uint32_t> decorationOperands;
403  if (words.size() > 3) {
404  decorationOperands = words.slice(3);
405  }
406  memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
407  return success();
408 }
409 
410 LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
411  if (words.size() < 3) {
412  return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
413  }
414  unsigned wordIndex = 2;
415  auto name = decodeStringLiteral(words, wordIndex);
416  if (wordIndex != words.size()) {
417  return emitError(unknownLoc,
418  "unexpected trailing words in OpMemberName instruction");
419  }
420  memberNameMap[words[0]][words[1]] = name;
421  return success();
422 }
423 
424 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
425  uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
426  if (!decorations.contains(argID)) {
427  argAttrs[argIndex] = DictionaryAttr::get(context, {});
428  return success();
429  }
430 
431  spirv::DecorationAttr foundDecorationAttr;
432  for (NamedAttribute decAttr : decorations[argID]) {
433  for (auto decoration :
434  {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435  spirv::Decoration::AliasedPointer,
436  spirv::Decoration::RestrictPointer}) {
437 
438  if (decAttr.getName() !=
439  getSymbolDecoration(stringifyDecoration(decoration)))
440  continue;
441 
442  if (foundDecorationAttr)
443  return emitError(unknownLoc,
444  "more than one Aliased/Restrict decorations for "
445  "function argument with result <id> ")
446  << argID;
447 
448  foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
449  break;
450  }
451 
452  if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
453  spirv::Decoration::RelaxedPrecision))) {
454  // TODO: Current implementation supports only one decoration per function
455  // parameter so RelaxedPrecision cannot be applied at the same time as,
456  // for example, Aliased/Restrict/etc. This should be relaxed to allow any
457  // combination of decoration allowed by the spec to be supported.
458  if (foundDecorationAttr)
459  return emitError(unknownLoc, "already found a decoration for function "
460  "argument with result <id> ")
461  << argID;
462 
463  foundDecorationAttr = spirv::DecorationAttr::get(
464  context, spirv::Decoration::RelaxedPrecision);
465  }
466  }
467 
468  if (!foundDecorationAttr)
469  return emitError(unknownLoc, "unimplemented decoration support for "
470  "function argument with result <id> ")
471  << argID;
472 
473  NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
474  foundDecorationAttr);
475  argAttrs[argIndex] = DictionaryAttr::get(context, attr);
476  return success();
477 }
478 
479 LogicalResult
480 spirv::Deserializer::processFunction(ArrayRef<uint32_t> operands) {
481  if (curFunction) {
482  return emitError(unknownLoc, "found function inside function");
483  }
484 
485  // Get the result type
486  if (operands.size() != 4) {
487  return emitError(unknownLoc, "OpFunction must have 4 parameters");
488  }
489  Type resultType = getType(operands[0]);
490  if (!resultType) {
491  return emitError(unknownLoc, "undefined result type from <id> ")
492  << operands[0];
493  }
494 
495  uint32_t fnID = operands[1];
496  if (funcMap.count(fnID)) {
497  return emitError(unknownLoc, "duplicate function definition/declaration");
498  }
499 
500  auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
501  if (!fnControl) {
502  return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
503  }
504 
505  Type fnType = getType(operands[3]);
506  if (!fnType || !isa<FunctionType>(fnType)) {
507  return emitError(unknownLoc, "unknown function type from <id> ")
508  << operands[3];
509  }
510  auto functionType = cast<FunctionType>(fnType);
511 
512  if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
513  (functionType.getNumResults() == 1 &&
514  functionType.getResult(0) != resultType)) {
515  return emitError(unknownLoc, "mismatch in function type ")
516  << functionType << " and return type " << resultType << " specified";
517  }
518 
519  std::string fnName = getFunctionSymbol(fnID);
520  auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
521  functionType, fnControl.value());
522  // Processing other function attributes.
523  if (decorations.count(fnID)) {
524  for (auto attr : decorations[fnID].getAttrs()) {
525  funcOp->setAttr(attr.getName(), attr.getValue());
526  }
527  }
528  curFunction = funcMap[fnID] = funcOp;
529  auto *entryBlock = funcOp.addEntryBlock();
530  LLVM_DEBUG({
531  logger.startLine()
532  << "//===-------------------------------------------===//\n";
533  logger.startLine() << "[fn] name: " << fnName << "\n";
534  logger.startLine() << "[fn] type: " << fnType << "\n";
535  logger.startLine() << "[fn] ID: " << fnID << "\n";
536  logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
537  logger.indent();
538  });
539 
540  SmallVector<Attribute> argAttrs;
541  argAttrs.resize(functionType.getNumInputs());
542 
543  // Parse the op argument instructions
544  if (functionType.getNumInputs()) {
545  for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
546  auto argType = functionType.getInput(i);
547  spirv::Opcode opcode = spirv::Opcode::OpNop;
548  ArrayRef<uint32_t> operands;
549  if (failed(sliceInstruction(opcode, operands,
550  spirv::Opcode::OpFunctionParameter))) {
551  return failure();
552  }
553  if (opcode != spirv::Opcode::OpFunctionParameter) {
554  return emitError(
555  unknownLoc,
556  "missing OpFunctionParameter instruction for argument ")
557  << i;
558  }
559  if (operands.size() != 2) {
560  return emitError(
561  unknownLoc,
562  "expected result type and result <id> for OpFunctionParameter");
563  }
564  auto argDefinedType = getType(operands[0]);
565  if (!argDefinedType || argDefinedType != argType) {
566  return emitError(unknownLoc,
567  "mismatch in argument type between function type "
568  "definition ")
569  << functionType << " and argument type definition "
570  << argDefinedType << " at argument " << i;
571  }
572  if (getValue(operands[1])) {
573  return emitError(unknownLoc, "duplicate definition of result <id> ")
574  << operands[1];
575  }
576  if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
577  return failure();
578  }
579 
580  auto argValue = funcOp.getArgument(i);
581  valueMap[operands[1]] = argValue;
582  }
583  }
584 
585  if (llvm::any_of(argAttrs, [](Attribute attr) {
586  auto argAttr = cast<DictionaryAttr>(attr);
587  return !argAttr.empty();
588  }))
589  funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
590 
591  // entryBlock is needed to access the arguments, Once that is done, we can
592  // erase the block for functions with 'Import' LinkageAttributes, since these
593  // are essentially function declarations, so they have no body.
594  auto linkageAttr = funcOp.getLinkageAttributes();
595  auto hasImportLinkage =
596  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
597  spirv::LinkageType::Import);
598  if (hasImportLinkage)
599  funcOp.eraseBody();
600 
601  // RAII guard to reset the insertion point to the module's region after
602  // deserializing the body of this function.
603  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
604 
605  spirv::Opcode opcode = spirv::Opcode::OpNop;
606  ArrayRef<uint32_t> instOperands;
607 
608  // Special handling for the entry block. We need to make sure it starts with
609  // an OpLabel instruction. The entry block takes the same parameters as the
610  // function. All other blocks do not take any parameter. We have already
611  // created the entry block, here we need to register it to the correct label
612  // <id>.
613  if (failed(sliceInstruction(opcode, instOperands,
614  spirv::Opcode::OpFunctionEnd))) {
615  return failure();
616  }
617  if (opcode == spirv::Opcode::OpFunctionEnd) {
618  return processFunctionEnd(instOperands);
619  }
620  if (opcode != spirv::Opcode::OpLabel) {
621  return emitError(unknownLoc, "a basic block must start with OpLabel");
622  }
623  if (instOperands.size() != 1) {
624  return emitError(unknownLoc, "OpLabel should only have result <id>");
625  }
626  blockMap[instOperands[0]] = entryBlock;
627  if (failed(processLabel(instOperands))) {
628  return failure();
629  }
630 
631  // Then process all the other instructions in the function until we hit
632  // OpFunctionEnd.
633  while (succeeded(sliceInstruction(opcode, instOperands,
634  spirv::Opcode::OpFunctionEnd)) &&
635  opcode != spirv::Opcode::OpFunctionEnd) {
636  if (failed(processInstruction(opcode, instOperands))) {
637  return failure();
638  }
639  }
640  if (opcode != spirv::Opcode::OpFunctionEnd) {
641  return failure();
642  }
643 
644  return processFunctionEnd(instOperands);
645 }
646 
647 LogicalResult
648 spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
649  // Process OpFunctionEnd.
650  if (!operands.empty()) {
651  return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
652  }
653 
654  // Wire up block arguments from OpPhi instructions.
655  // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
656  // ops.
657  if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
658  return failure();
659  }
660 
661  curBlock = nullptr;
662  curFunction = std::nullopt;
663 
664  LLVM_DEBUG({
665  logger.unindent();
666  logger.startLine()
667  << "//===-------------------------------------------===//\n";
668  });
669  return success();
670 }
671 
672 LogicalResult
673 spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
674  if (operands.size() < 2) {
675  return emitError(unknownLoc,
676  "missing graph defintion in OpGraphEntryPointARM");
677  }
678 
679  unsigned wordIndex = 0;
680  uint32_t graphID = operands[wordIndex++];
681  if (!graphMap.contains(graphID)) {
682  return emitError(unknownLoc,
683  "missing graph definition/declaration with id ")
684  << graphID;
685  }
686 
687  spirv::GraphARMOp graphARM = graphMap[graphID];
688  StringRef name = decodeStringLiteral(operands, wordIndex);
689  graphARM.setSymName(name);
690  graphARM.setEntryPoint(true);
691 
692  SmallVector<Attribute, 4> interface;
693  for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
694  if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
695  interface.push_back(SymbolRefAttr::get(arg.getOperation()));
696  } else {
697  return emitError(unknownLoc, "undefined result <id> ")
698  << operands[wordIndex] << " while decoding OpGraphEntryPoint";
699  }
700  }
701 
702  // RAII guard to reset the insertion point to previous value when done.
703  OpBuilder::InsertionGuard insertionGuard(opBuilder);
704  opBuilder.setInsertionPoint(graphARM);
705  opBuilder.create<spirv::GraphEntryPointARMOp>(
706  unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
707  opBuilder.getArrayAttr(interface));
708 
709  return success();
710 }
711 
712 LogicalResult
713 spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
714  if (curGraph) {
715  return emitError(unknownLoc, "found graph inside graph");
716  }
717  // Get the result type.
718  if (operands.size() < 2) {
719  return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
720  }
721 
722  Type type = getType(operands[0]);
723  if (!type || !isa<GraphType>(type)) {
724  return emitError(unknownLoc, "unknown graph type from <id> ")
725  << operands[0];
726  }
727  auto graphType = cast<GraphType>(type);
728  if (graphType.getNumResults() <= 0) {
729  return emitError(unknownLoc, "expected at least one result");
730  }
731 
732  uint32_t graphID = operands[1];
733  if (graphMap.count(graphID)) {
734  return emitError(unknownLoc, "duplicate graph definition/declaration");
735  }
736 
737  std::string graphName = getGraphSymbol(graphID);
738  auto graphOp =
739  opBuilder.create<spirv::GraphARMOp>(unknownLoc, graphName, graphType);
740  curGraph = graphMap[graphID] = graphOp;
741  Block *entryBlock = graphOp.addEntryBlock();
742  LLVM_DEBUG({
743  logger.startLine()
744  << "//===-------------------------------------------===//\n";
745  logger.startLine() << "[graph] name: " << graphName << "\n";
746  logger.startLine() << "[graph] type: " << graphType << "\n";
747  logger.startLine() << "[graph] ID: " << graphID << "\n";
748  logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
749  logger.indent();
750  });
751 
752  // Parse the op argument instructions.
753  for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
754  spirv::Opcode opcode;
755  ArrayRef<uint32_t> operands;
756  if (failed(sliceInstruction(opcode, operands,
757  spirv::Opcode::OpGraphInputARM))) {
758  return failure();
759  }
760  if (operands.size() != 3) {
761  return emitError(unknownLoc, "expected result type, result <id> and "
762  "input index for OpGraphInputARM");
763  }
764 
765  Type argDefinedType = getType(operands[0]);
766  if (!argDefinedType) {
767  return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
768  }
769 
770  if (argDefinedType != argType) {
771  return emitError(unknownLoc,
772  "mismatch in argument type between graph type "
773  "definition ")
774  << graphType << " and argument type definition " << argDefinedType
775  << " at argument " << index;
776  }
777  if (getValue(operands[1])) {
778  return emitError(unknownLoc, "duplicate definition of result <id> ")
779  << operands[1];
780  }
781 
782  IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
783  if (!inputIndexAttr) {
784  return emitError(unknownLoc,
785  "unable to read inputIndex value from constant op ")
786  << operands[2];
787  }
788  BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
789  valueMap[operands[1]] = argValue;
790  }
791 
792  graphOutputs.resize(graphType.getNumResults());
793 
794  // RAII guard to reset the insertion point to the module's region after
795  // deserializing the body of this function.
796  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
797 
798  blockMap[graphID] = entryBlock;
799  if (failed(createGraphBlock(graphID))) {
800  return failure();
801  }
802 
803  // Process all the instructions in the graph until and including
804  // OpGraphEndARM.
805  spirv::Opcode opcode;
806  ArrayRef<uint32_t> instOperands;
807  do {
808  if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
809  return failure();
810  }
811 
812  if (failed(processInstruction(opcode, instOperands))) {
813  return failure();
814  }
815  } while (opcode != spirv::Opcode::OpGraphEndARM);
816 
817  return success();
818 }
819 
820 LogicalResult
821 spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
822  if (operands.size() != 2) {
823  return emitError(
824  unknownLoc,
825  "expected value id and output index for OpGraphSetOutputARM");
826  }
827 
828  uint32_t id = operands[0];
829  Value value = getValue(id);
830  if (!value) {
831  return emitError(unknownLoc, "could not find result <id> ") << id;
832  }
833 
834  IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
835  if (!outputIndexAttr) {
836  return emitError(unknownLoc,
837  "unable to read outputIndex value from constant op ")
838  << operands[1];
839  }
840  graphOutputs[outputIndexAttr.getInt()] = value;
841  return success();
842 }
843 
844 LogicalResult
845 spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
846  // Create GraphOutputsARM instruction.
847  opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
848 
849  // Process OpGraphEndARM.
850  if (!operands.empty()) {
851  return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
852  }
853 
854  curBlock = nullptr;
855  curGraph = std::nullopt;
856  graphOutputs.clear();
857 
858  LLVM_DEBUG({
859  logger.unindent();
860  logger.startLine()
861  << "//===-------------------------------------------===//\n";
862  });
863  return success();
864 }
865 
866 std::optional<std::pair<Attribute, Type>>
867 spirv::Deserializer::getConstant(uint32_t id) {
868  auto constIt = constantMap.find(id);
869  if (constIt == constantMap.end())
870  return std::nullopt;
871  return constIt->getSecond();
872 }
873 
874 std::optional<std::pair<Attribute, Type>>
875 spirv::Deserializer::getConstantCompositeReplicate(uint32_t id) {
876  if (auto it = constantCompositeReplicateMap.find(id);
877  it != constantCompositeReplicateMap.end())
878  return it->second;
879  return std::nullopt;
880 }
881 
882 std::optional<spirv::SpecConstOperationMaterializationInfo>
883 spirv::Deserializer::getSpecConstantOperation(uint32_t id) {
884  auto constIt = specConstOperationMap.find(id);
885  if (constIt == specConstOperationMap.end())
886  return std::nullopt;
887  return constIt->getSecond();
888 }
889 
890 std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
891  auto funcName = nameMap.lookup(id).str();
892  if (funcName.empty()) {
893  funcName = "spirv_fn_" + std::to_string(id);
894  }
895  return funcName;
896 }
897 
898 std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
899  std::string graphName = nameMap.lookup(id).str();
900  if (graphName.empty()) {
901  graphName = "spirv_graph_" + std::to_string(id);
902  }
903  return graphName;
904 }
905 
906 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
907  auto constName = nameMap.lookup(id).str();
908  if (constName.empty()) {
909  constName = "spirv_spec_const_" + std::to_string(id);
910  }
911  return constName;
912 }
913 
914 spirv::SpecConstantOp
915 spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
916  TypedAttr defaultValue) {
917  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
918  auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
919  defaultValue);
920  if (decorations.count(resultID)) {
921  for (auto attr : decorations[resultID].getAttrs())
922  op->setAttr(attr.getName(), attr.getValue());
923  }
924  specConstMap[resultID] = op;
925  return op;
926 }
927 
928 std::optional<spirv::GraphConstantARMOpMaterializationInfo>
929 spirv::Deserializer::getGraphConstantARM(uint32_t id) {
930  auto graphConstIt = graphConstantMap.find(id);
931  if (graphConstIt == graphConstantMap.end())
932  return std::nullopt;
933  return graphConstIt->getSecond();
934 }
935 
936 LogicalResult
937 spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
938  unsigned wordIndex = 0;
939  if (operands.size() < 3) {
940  return emitError(
941  unknownLoc,
942  "OpVariable needs at least 3 operands, type, <id> and storage class");
943  }
944 
945  // Result Type.
946  auto type = getType(operands[wordIndex]);
947  if (!type) {
948  return emitError(unknownLoc, "unknown result type <id> : ")
949  << operands[wordIndex];
950  }
951  auto ptrType = dyn_cast<spirv::PointerType>(type);
952  if (!ptrType) {
953  return emitError(unknownLoc,
954  "expected a result type <id> to be a spirv.ptr, found : ")
955  << type;
956  }
957  wordIndex++;
958 
959  // Result <id>.
960  auto variableID = operands[wordIndex];
961  auto variableName = nameMap.lookup(variableID).str();
962  if (variableName.empty()) {
963  variableName = "spirv_var_" + std::to_string(variableID);
964  }
965  wordIndex++;
966 
967  // Storage class.
968  auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
969  if (ptrType.getStorageClass() != storageClass) {
970  return emitError(unknownLoc, "mismatch in storage class of pointer type ")
971  << type << " and that specified in OpVariable instruction : "
972  << stringifyStorageClass(storageClass);
973  }
974  wordIndex++;
975 
976  // Initializer.
977  FlatSymbolRefAttr initializer = nullptr;
978 
979  if (wordIndex < operands.size()) {
980  Operation *op = nullptr;
981 
982  if (auto initOp = getGlobalVariable(operands[wordIndex]))
983  op = initOp;
984  else if (auto initOp = getSpecConstant(operands[wordIndex]))
985  op = initOp;
986  else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
987  op = initOp;
988  else
989  return emitError(unknownLoc, "unknown <id> ")
990  << operands[wordIndex] << "used as initializer";
991 
992  initializer = SymbolRefAttr::get(op);
993  wordIndex++;
994  }
995  if (wordIndex != operands.size()) {
996  return emitError(unknownLoc,
997  "found more operands than expected when deserializing "
998  "OpVariable instruction, only ")
999  << wordIndex << " of " << operands.size() << " processed";
1000  }
1001  auto loc = createFileLineColLoc(opBuilder);
1002  auto varOp = spirv::GlobalVariableOp::create(
1003  opBuilder, loc, TypeAttr::get(type),
1004  opBuilder.getStringAttr(variableName), initializer);
1005 
1006  // Decorations.
1007  if (decorations.count(variableID)) {
1008  for (auto attr : decorations[variableID].getAttrs())
1009  varOp->setAttr(attr.getName(), attr.getValue());
1010  }
1011  globalVariableMap[variableID] = varOp;
1012  return success();
1013 }
1014 
1015 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
1016  auto constInfo = getConstant(id);
1017  if (!constInfo) {
1018  return nullptr;
1019  }
1020  return dyn_cast<IntegerAttr>(constInfo->first);
1021 }
1022 
1023 LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
1024  if (operands.size() < 2) {
1025  return emitError(unknownLoc, "OpName needs at least 2 operands");
1026  }
1027  if (!nameMap.lookup(operands[0]).empty()) {
1028  return emitError(unknownLoc, "duplicate name found for result <id> ")
1029  << operands[0];
1030  }
1031  unsigned wordIndex = 1;
1032  StringRef name = decodeStringLiteral(operands, wordIndex);
1033  if (wordIndex != operands.size()) {
1034  return emitError(unknownLoc,
1035  "unexpected trailing words in OpName instruction");
1036  }
1037  nameMap[operands[0]] = name;
1038  return success();
1039 }
1040 
1041 //===----------------------------------------------------------------------===//
1042 // Type
1043 //===----------------------------------------------------------------------===//
1044 
1045 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
1046  ArrayRef<uint32_t> operands) {
1047  if (operands.empty()) {
1048  return emitError(unknownLoc, "type instruction with opcode ")
1049  << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
1050  }
1051 
1052  /// TODO: Types might be forward declared in some instructions and need to be
1053  /// handled appropriately.
1054  if (typeMap.count(operands[0])) {
1055  return emitError(unknownLoc, "duplicate definition for result <id> ")
1056  << operands[0];
1057  }
1058 
1059  switch (opcode) {
1060  case spirv::Opcode::OpTypeVoid:
1061  if (operands.size() != 1)
1062  return emitError(unknownLoc, "OpTypeVoid must have no parameters");
1063  typeMap[operands[0]] = opBuilder.getNoneType();
1064  break;
1065  case spirv::Opcode::OpTypeBool:
1066  if (operands.size() != 1)
1067  return emitError(unknownLoc, "OpTypeBool must have no parameters");
1068  typeMap[operands[0]] = opBuilder.getI1Type();
1069  break;
1070  case spirv::Opcode::OpTypeInt: {
1071  if (operands.size() != 3)
1072  return emitError(
1073  unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
1074 
1075  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1076  // to preserve or validate.
1077  // 0 indicates unsigned, or no signedness semantics
1078  // 1 indicates signed semantics."
1079  //
1080  // So we cannot differentiate signless and unsigned integers; always use
1081  // signless semantics for such cases.
1082  auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1083  : IntegerType::SignednessSemantics::Signless;
1084  typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1085  } break;
1086  case spirv::Opcode::OpTypeFloat: {
1087  if (operands.size() != 2 && operands.size() != 3)
1088  return emitError(unknownLoc,
1089  "OpTypeFloat expects either 2 operands (type, bitwidth) "
1090  "or 3 operands (type, bitwidth, encoding), but got ")
1091  << operands.size();
1092  uint32_t bitWidth = operands[1];
1093 
1094  Type floatTy;
1095  switch (bitWidth) {
1096  case 16:
1097  floatTy = opBuilder.getF16Type();
1098  break;
1099  case 32:
1100  floatTy = opBuilder.getF32Type();
1101  break;
1102  case 64:
1103  floatTy = opBuilder.getF64Type();
1104  break;
1105  default:
1106  return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
1107  << bitWidth;
1108  }
1109 
1110  if (operands.size() == 3) {
1111  if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
1112  return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
1113  << operands[2];
1114  if (bitWidth != 16)
1115  return emitError(unknownLoc,
1116  "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
1117  << bitWidth << " (expected 16)";
1118  floatTy = opBuilder.getBF16Type();
1119  }
1120 
1121  typeMap[operands[0]] = floatTy;
1122  } break;
1123  case spirv::Opcode::OpTypeVector: {
1124  if (operands.size() != 3) {
1125  return emitError(
1126  unknownLoc,
1127  "OpTypeVector must have element type and count parameters");
1128  }
1129  Type elementTy = getType(operands[1]);
1130  if (!elementTy) {
1131  return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
1132  << operands[1];
1133  }
1134  typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1135  } break;
1136  case spirv::Opcode::OpTypePointer: {
1137  return processOpTypePointer(operands);
1138  } break;
1139  case spirv::Opcode::OpTypeArray:
1140  return processArrayType(operands);
1141  case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1142  return processCooperativeMatrixTypeKHR(operands);
1143  case spirv::Opcode::OpTypeFunction:
1144  return processFunctionType(operands);
1145  case spirv::Opcode::OpTypeImage:
1146  return processImageType(operands);
1147  case spirv::Opcode::OpTypeSampledImage:
1148  return processSampledImageType(operands);
1149  case spirv::Opcode::OpTypeRuntimeArray:
1150  return processRuntimeArrayType(operands);
1151  case spirv::Opcode::OpTypeStruct:
1152  return processStructType(operands);
1153  case spirv::Opcode::OpTypeMatrix:
1154  return processMatrixType(operands);
1155  case spirv::Opcode::OpTypeTensorARM:
1156  return processTensorARMType(operands);
1157  case spirv::Opcode::OpTypeGraphARM:
1158  return processGraphTypeARM(operands);
1159  default:
1160  return emitError(unknownLoc, "unhandled type instruction");
1161  }
1162  return success();
1163 }
1164 
1165 LogicalResult
1166 spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
1167  if (operands.size() != 3)
1168  return emitError(unknownLoc, "OpTypePointer must have two parameters");
1169 
1170  auto pointeeType = getType(operands[2]);
1171  if (!pointeeType)
1172  return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
1173  << operands[2];
1174 
1175  uint32_t typePointerID = operands[0];
1176  auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
1177  typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
1178 
1179  for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1180  deferredStructIt != std::end(deferredStructTypesInfos);) {
1181  for (auto *unresolvedMemberIt =
1182  std::begin(deferredStructIt->unresolvedMemberTypes);
1183  unresolvedMemberIt !=
1184  std::end(deferredStructIt->unresolvedMemberTypes);) {
1185  if (unresolvedMemberIt->first == typePointerID) {
1186  // The newly constructed pointer type can resolve one of the
1187  // deferred struct type members; update the memberTypes list and
1188  // clean the unresolvedMemberTypes list accordingly.
1189  deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1190  typeMap[typePointerID];
1191  unresolvedMemberIt =
1192  deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1193  } else {
1194  ++unresolvedMemberIt;
1195  }
1196  }
1197 
1198  if (deferredStructIt->unresolvedMemberTypes.empty()) {
1199  // All deferred struct type members are now resolved, set the struct body.
1200  auto structType = deferredStructIt->deferredStructType;
1201 
1202  assert(structType && "expected a spirv::StructType");
1203  assert(structType.isIdentified() && "expected an indentified struct");
1204 
1205  if (failed(structType.trySetBody(
1206  deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1207  deferredStructIt->memberDecorationsInfo,
1208  deferredStructIt->structDecorationsInfo)))
1209  return failure();
1210 
1211  deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1212  } else {
1213  ++deferredStructIt;
1214  }
1215  }
1216 
1217  return success();
1218 }
1219 
1220 LogicalResult
1221 spirv::Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
1222  if (operands.size() != 3) {
1223  return emitError(unknownLoc,
1224  "OpTypeArray must have element type and count parameters");
1225  }
1226 
1227  Type elementTy = getType(operands[1]);
1228  if (!elementTy) {
1229  return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1230  << operands[1];
1231  }
1232 
1233  unsigned count = 0;
1234  // TODO: The count can also come frome a specialization constant.
1235  auto countInfo = getConstant(operands[2]);
1236  if (!countInfo) {
1237  return emitError(unknownLoc, "OpTypeArray count <id> ")
1238  << operands[2] << "can only come from normal constant right now";
1239  }
1240 
1241  if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1242  count = intVal.getValue().getZExtValue();
1243  } else {
1244  return emitError(unknownLoc, "OpTypeArray count must come from a "
1245  "scalar integer constant instruction");
1246  }
1247 
1248  typeMap[operands[0]] = spirv::ArrayType::get(
1249  elementTy, count, typeDecorations.lookup(operands[0]));
1250  return success();
1251 }
1252 
1253 LogicalResult
1254 spirv::Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
1255  assert(!operands.empty() && "No operands for processing function type");
1256  if (operands.size() == 1) {
1257  return emitError(unknownLoc, "missing return type for OpTypeFunction");
1258  }
1259  auto returnType = getType(operands[1]);
1260  if (!returnType) {
1261  return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1262  }
1263  SmallVector<Type, 1> argTypes;
1264  for (size_t i = 2, e = operands.size(); i < e; ++i) {
1265  auto ty = getType(operands[i]);
1266  if (!ty) {
1267  return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1268  }
1269  argTypes.push_back(ty);
1270  }
1271  ArrayRef<Type> returnTypes;
1272  if (!isVoidType(returnType)) {
1273  returnTypes = llvm::ArrayRef(returnType);
1274  }
1275  typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1276  return success();
1277 }
1278 
1279 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1280  ArrayRef<uint32_t> operands) {
1281  if (operands.size() != 6) {
1282  return emitError(unknownLoc,
1283  "OpTypeCooperativeMatrixKHR must have element type, "
1284  "scope, row and column parameters, and use");
1285  }
1286 
1287  Type elementTy = getType(operands[1]);
1288  if (!elementTy) {
1289  return emitError(unknownLoc,
1290  "OpTypeCooperativeMatrixKHR references undefined <id> ")
1291  << operands[1];
1292  }
1293 
1294  std::optional<spirv::Scope> scope =
1295  spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1296  if (!scope) {
1297  return emitError(
1298  unknownLoc,
1299  "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1300  << operands[2];
1301  }
1302 
1303  IntegerAttr rowsAttr = getConstantInt(operands[3]);
1304  IntegerAttr columnsAttr = getConstantInt(operands[4]);
1305  IntegerAttr useAttr = getConstantInt(operands[5]);
1306 
1307  if (!rowsAttr)
1308  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1309  "undefined constant <id> ")
1310  << operands[3];
1311 
1312  if (!columnsAttr)
1313  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1314  "references undefined constant <id> ")
1315  << operands[4];
1316 
1317  if (!useAttr)
1318  return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1319  "undefined constant <id> ")
1320  << operands[5];
1321 
1322  unsigned rows = rowsAttr.getInt();
1323  unsigned columns = columnsAttr.getInt();
1324 
1325  std::optional<spirv::CooperativeMatrixUseKHR> use =
1326  spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1327  if (!use) {
1328  return emitError(
1329  unknownLoc,
1330  "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1331  << operands[5];
1332  }
1333 
1334  typeMap[operands[0]] =
1335  spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1336  return success();
1337 }
1338 
1339 LogicalResult
1340 spirv::Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
1341  if (operands.size() != 2) {
1342  return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1343  }
1344  Type memberType = getType(operands[1]);
1345  if (!memberType) {
1346  return emitError(unknownLoc,
1347  "OpTypeRuntimeArray references undefined <id> ")
1348  << operands[1];
1349  }
1350  typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1351  memberType, typeDecorations.lookup(operands[0]));
1352  return success();
1353 }
1354 
1355 LogicalResult
1356 spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
1357  // TODO: Find a way to handle identified structs when debug info is stripped.
1358 
1359  if (operands.empty()) {
1360  return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1361  }
1362 
1363  if (operands.size() == 1) {
1364  // Handle empty struct.
1365  typeMap[operands[0]] =
1366  spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1367  return success();
1368  }
1369 
1370  // First element is operand ID, second element is member index in the struct.
1371  SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1372  SmallVector<Type, 4> memberTypes;
1373 
1374  for (auto op : llvm::drop_begin(operands, 1)) {
1375  Type memberType = getType(op);
1376  bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1377 
1378  if (!memberType && !typeForwardPtr)
1379  return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1380  << op;
1381 
1382  if (!memberType)
1383  unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1384 
1385  memberTypes.push_back(memberType);
1386  }
1387 
1390  if (memberDecorationMap.count(operands[0])) {
1391  auto &allMemberDecorations = memberDecorationMap[operands[0]];
1392  for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1393  if (allMemberDecorations.count(memberIndex)) {
1394  for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1395  // Check for offset.
1396  if (memberDecoration.first == spirv::Decoration::Offset) {
1397  // If offset info is empty, resize to the number of members;
1398  if (offsetInfo.empty()) {
1399  offsetInfo.resize(memberTypes.size());
1400  }
1401  offsetInfo[memberIndex] = memberDecoration.second[0];
1402  } else {
1403  auto intType = mlir::IntegerType::get(context, 32);
1404  if (!memberDecoration.second.empty()) {
1405  memberDecorationsInfo.emplace_back(
1406  memberIndex, memberDecoration.first,
1407  IntegerAttr::get(intType, memberDecoration.second[0]));
1408  } else {
1409  memberDecorationsInfo.emplace_back(
1410  memberIndex, memberDecoration.first, UnitAttr::get(context));
1411  }
1412  }
1413  }
1414  }
1415  }
1416  }
1417 
1419  if (decorations.count(operands[0])) {
1420  NamedAttrList &allDecorations = decorations[operands[0]];
1421  for (NamedAttribute &decorationAttr : allDecorations) {
1422  std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1423  llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
1424  assert(decoration.has_value());
1425  structDecorationsInfo.emplace_back(decoration.value(),
1426  decorationAttr.getValue());
1427  }
1428  }
1429 
1430  uint32_t structID = operands[0];
1431  std::string structIdentifier = nameMap.lookup(structID).str();
1432 
1433  if (structIdentifier.empty()) {
1434  assert(unresolvedMemberTypes.empty() &&
1435  "didn't expect unresolved member types");
1436  typeMap[structID] = spirv::StructType::get(
1437  memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1438  } else {
1439  auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1440  typeMap[structID] = structTy;
1441 
1442  if (!unresolvedMemberTypes.empty())
1443  deferredStructTypesInfos.push_back(
1444  {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1445  memberDecorationsInfo, structDecorationsInfo});
1446  else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1447  memberDecorationsInfo,
1448  structDecorationsInfo)))
1449  return failure();
1450  }
1451 
1452  // TODO: Update StructType to have member name as attribute as
1453  // well.
1454  return success();
1455 }
1456 
1457 LogicalResult
1458 spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
1459  if (operands.size() != 3) {
1460  // Three operands are needed: result_id, column_type, and column_count
1461  return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1462  " (result_id, column_type, and column_count)");
1463  }
1464  // Matrix columns must be of vector type
1465  Type elementTy = getType(operands[1]);
1466  if (!elementTy) {
1467  return emitError(unknownLoc,
1468  "OpTypeMatrix references undefined column type.")
1469  << operands[1];
1470  }
1471 
1472  uint32_t colsCount = operands[2];
1473  typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1474  return success();
1475 }
1476 
1477 LogicalResult
1478 spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
1479  unsigned size = operands.size();
1480  if (size < 2 || size > 4)
1481  return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
1482  "(result_id, element_type, (rank), (shape)) ")
1483  << size;
1484 
1485  Type elementTy = getType(operands[1]);
1486  if (!elementTy)
1487  return emitError(unknownLoc,
1488  "OpTypeTensorARM references undefined element type ")
1489  << operands[1];
1490 
1491  if (size == 2) {
1492  typeMap[operands[0]] = TensorArmType::get({}, elementTy);
1493  return success();
1494  }
1495 
1496  IntegerAttr rankAttr = getConstantInt(operands[2]);
1497  if (!rankAttr)
1498  return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
1499  "scalar integer constant instruction");
1500  unsigned rank = rankAttr.getValue().getZExtValue();
1501  if (size == 3) {
1502  SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1503  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1504  return success();
1505  }
1506 
1507  std::optional<std::pair<Attribute, Type>> shapeInfo =
1508  getConstant(operands[3]);
1509  if (!shapeInfo)
1510  return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
1511  "constant instruction of type OpTypeArray");
1512 
1513  ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
1515  for (auto dimAttr : shapeArrayAttr.getValue()) {
1516  auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
1517  if (!dimIntAttr)
1518  return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
1519  "dimension size");
1520  shape.push_back(dimIntAttr.getValue().getSExtValue());
1521  }
1522  typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1523  return success();
1524 }
1525 
1526 LogicalResult
1527 spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
1528  unsigned size = operands.size();
1529  if (size < 2) {
1530  return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
1531  "(result_id, num_inputs, (inout0_type, "
1532  "inout1_type, ...))")
1533  << size;
1534  }
1535  uint32_t numInputs = operands[1];
1536  SmallVector<Type, 1> argTypes;
1537  SmallVector<Type, 1> returnTypes;
1538  for (unsigned i = 2; i < size; ++i) {
1539  Type inOutTy = getType(operands[i]);
1540  if (!inOutTy) {
1541  return emitError(unknownLoc,
1542  "OpTypeGraphARM references undefined element type.")
1543  << operands[i];
1544  }
1545  if (i - 2 >= numInputs) {
1546  returnTypes.push_back(inOutTy);
1547  } else {
1548  argTypes.push_back(inOutTy);
1549  }
1550  }
1551  typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1552  return success();
1553 }
1554 
1555 LogicalResult
1556 spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
1557  if (operands.size() != 2)
1558  return emitError(unknownLoc,
1559  "OpTypeForwardPointer instruction must have two operands");
1560 
1561  typeForwardPointerIDs.insert(operands[0]);
1562  // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1563  // instruction that defines the actual type.
1564 
1565  return success();
1566 }
1567 
1568 LogicalResult
1569 spirv::Deserializer::processImageType(ArrayRef<uint32_t> operands) {
1570  // TODO: Add support for Access Qualifier.
1571  if (operands.size() != 8)
1572  return emitError(
1573  unknownLoc,
1574  "OpTypeImage with non-eight operands are not supported yet");
1575 
1576  Type elementTy = getType(operands[1]);
1577  if (!elementTy)
1578  return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1579  << operands[1];
1580 
1581  auto dim = spirv::symbolizeDim(operands[2]);
1582  if (!dim)
1583  return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1584  << operands[2];
1585 
1586  auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1587  if (!depthInfo)
1588  return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1589  << operands[3];
1590 
1591  auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1592  if (!arrayedInfo)
1593  return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1594  << operands[4];
1595 
1596  auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1597  if (!samplingInfo)
1598  return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1599 
1600  auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1601  if (!samplerUseInfo)
1602  return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1603  << operands[6];
1604 
1605  auto format = spirv::symbolizeImageFormat(operands[7]);
1606  if (!format)
1607  return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1608  << operands[7];
1609 
1610  typeMap[operands[0]] = spirv::ImageType::get(
1611  elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1612  samplingInfo.value(), samplerUseInfo.value(), format.value());
1613  return success();
1614 }
1615 
1616 LogicalResult
1617 spirv::Deserializer::processSampledImageType(ArrayRef<uint32_t> operands) {
1618  if (operands.size() != 2)
1619  return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1620 
1621  Type elementTy = getType(operands[1]);
1622  if (!elementTy)
1623  return emitError(unknownLoc,
1624  "OpTypeSampledImage references undefined <id>: ")
1625  << operands[1];
1626 
1627  typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1628  return success();
1629 }
1630 
1631 //===----------------------------------------------------------------------===//
1632 // Constant
1633 //===----------------------------------------------------------------------===//
1634 
1635 LogicalResult spirv::Deserializer::processConstant(ArrayRef<uint32_t> operands,
1636  bool isSpec) {
1637  StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1638 
1639  if (operands.size() < 2) {
1640  return emitError(unknownLoc)
1641  << opname << " must have type <id> and result <id>";
1642  }
1643  if (operands.size() < 3) {
1644  return emitError(unknownLoc)
1645  << opname << " must have at least 1 more parameter";
1646  }
1647 
1648  Type resultType = getType(operands[0]);
1649  if (!resultType) {
1650  return emitError(unknownLoc, "undefined result type from <id> ")
1651  << operands[0];
1652  }
1653 
1654  auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1655  if (bitwidth == 64) {
1656  if (operands.size() == 4) {
1657  return success();
1658  }
1659  return emitError(unknownLoc)
1660  << opname << " should have 2 parameters for 64-bit values";
1661  }
1662  if (bitwidth <= 32) {
1663  if (operands.size() == 3) {
1664  return success();
1665  }
1666 
1667  return emitError(unknownLoc)
1668  << opname
1669  << " should have 1 parameter for values with no more than 32 bits";
1670  }
1671  return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1672  << bitwidth;
1673  };
1674 
1675  auto resultID = operands[1];
1676 
1677  if (auto intType = dyn_cast<IntegerType>(resultType)) {
1678  auto bitwidth = intType.getWidth();
1679  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1680  return failure();
1681  }
1682 
1683  APInt value;
1684  if (bitwidth == 64) {
1685  // 64-bit integers are represented with two SPIR-V words. According to
1686  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1687  // literal’s low-order words appear first."
1688  struct DoubleWord {
1689  uint32_t word1;
1690  uint32_t word2;
1691  } words = {operands[2], operands[3]};
1692  value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1693  } else if (bitwidth <= 32) {
1694  value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1695  /*implicitTrunc=*/true);
1696  }
1697 
1698  auto attr = opBuilder.getIntegerAttr(intType, value);
1699 
1700  if (isSpec) {
1701  createSpecConstant(unknownLoc, resultID, attr);
1702  } else {
1703  // For normal constants, we just record the attribute (and its type) for
1704  // later materialization at use sites.
1705  constantMap.try_emplace(resultID, attr, intType);
1706  }
1707 
1708  return success();
1709  }
1710 
1711  if (auto floatType = dyn_cast<FloatType>(resultType)) {
1712  auto bitwidth = floatType.getWidth();
1713  if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1714  return failure();
1715  }
1716 
1717  APFloat value(0.f);
1718  if (floatType.isF64()) {
1719  // Double values are represented with two SPIR-V words. According to
1720  // SPIR-V spec: "When the type’s bit width is larger than one word, the
1721  // literal’s low-order words appear first."
1722  struct DoubleWord {
1723  uint32_t word1;
1724  uint32_t word2;
1725  } words = {operands[2], operands[3]};
1726  value = APFloat(llvm::bit_cast<double>(words));
1727  } else if (floatType.isF32()) {
1728  value = APFloat(llvm::bit_cast<float>(operands[2]));
1729  } else if (floatType.isF16()) {
1730  APInt data(16, operands[2]);
1731  value = APFloat(APFloat::IEEEhalf(), data);
1732  } else if (floatType.isBF16()) {
1733  APInt data(16, operands[2]);
1734  value = APFloat(APFloat::BFloat(), data);
1735  }
1736 
1737  auto attr = opBuilder.getFloatAttr(floatType, value);
1738  if (isSpec) {
1739  createSpecConstant(unknownLoc, resultID, attr);
1740  } else {
1741  // For normal constants, we just record the attribute (and its type) for
1742  // later materialization at use sites.
1743  constantMap.try_emplace(resultID, attr, floatType);
1744  }
1745 
1746  return success();
1747  }
1748 
1749  return emitError(unknownLoc, "OpConstant can only generate values of "
1750  "scalar integer or floating-point type");
1751 }
1752 
1753 LogicalResult spirv::Deserializer::processConstantBool(
1754  bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1755  if (operands.size() != 2) {
1756  return emitError(unknownLoc, "Op")
1757  << (isSpec ? "Spec" : "") << "Constant"
1758  << (isTrue ? "True" : "False")
1759  << " must have type <id> and result <id>";
1760  }
1761 
1762  auto attr = opBuilder.getBoolAttr(isTrue);
1763  auto resultID = operands[1];
1764  if (isSpec) {
1765  createSpecConstant(unknownLoc, resultID, attr);
1766  } else {
1767  // For normal constants, we just record the attribute (and its type) for
1768  // later materialization at use sites.
1769  constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1770  }
1771 
1772  return success();
1773 }
1774 
1775 LogicalResult
1776 spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
1777  if (operands.size() < 2) {
1778  return emitError(unknownLoc,
1779  "OpConstantComposite must have type <id> and result <id>");
1780  }
1781  if (operands.size() < 3) {
1782  return emitError(unknownLoc,
1783  "OpConstantComposite must have at least 1 parameter");
1784  }
1785 
1786  Type resultType = getType(operands[0]);
1787  if (!resultType) {
1788  return emitError(unknownLoc, "undefined result type from <id> ")
1789  << operands[0];
1790  }
1791 
1792  SmallVector<Attribute, 4> elements;
1793  elements.reserve(operands.size() - 2);
1794  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1795  auto elementInfo = getConstant(operands[i]);
1796  if (!elementInfo) {
1797  return emitError(unknownLoc, "OpConstantComposite component <id> ")
1798  << operands[i] << " must come from a normal constant";
1799  }
1800  elements.push_back(elementInfo->first);
1801  }
1802 
1803  auto resultID = operands[1];
1804  if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1805  SmallVector<Attribute> flattenedElems;
1806  for (Attribute element : elements) {
1807  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1808  for (auto value : denseElemAttr.getValues<Attribute>())
1809  flattenedElems.push_back(value);
1810  } else {
1811  flattenedElems.push_back(element);
1812  }
1813  }
1814  auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
1815  constantMap.try_emplace(resultID, attr, tensorType);
1816  } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1817  auto attr = DenseElementsAttr::get(shapedType, elements);
1818  // For normal constants, we just record the attribute (and its type) for
1819  // later materialization at use sites.
1820  constantMap.try_emplace(resultID, attr, shapedType);
1821  } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1822  auto attr = opBuilder.getArrayAttr(elements);
1823  constantMap.try_emplace(resultID, attr, resultType);
1824  } else {
1825  return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1826  << resultType;
1827  }
1828 
1829  return success();
1830 }
1831 
1832 LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
1833  ArrayRef<uint32_t> operands) {
1834  if (operands.size() != 3) {
1835  return emitError(
1836  unknownLoc,
1837  "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1838  << operands.size();
1839  }
1840 
1841  Type resultType = getType(operands[0]);
1842  if (!resultType) {
1843  return emitError(unknownLoc, "undefined result type from <id> ")
1844  << operands[0];
1845  }
1846 
1847  auto compositeType = dyn_cast<CompositeType>(resultType);
1848  if (!compositeType) {
1849  return emitError(unknownLoc,
1850  "result type from <id> is not a composite type")
1851  << operands[0];
1852  }
1853 
1854  uint32_t resultID = operands[1];
1855  uint32_t constantID = operands[2];
1856 
1857  std::optional<std::pair<Attribute, Type>> constantInfo =
1858  getConstant(constantID);
1859  if (constantInfo.has_value()) {
1860  constantCompositeReplicateMap.try_emplace(
1861  resultID, constantInfo.value().first, resultType);
1862  return success();
1863  }
1864 
1865  std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1866  getConstantCompositeReplicate(constantID);
1867  if (replicatedConstantCompositeInfo.has_value()) {
1868  constantCompositeReplicateMap.try_emplace(
1869  resultID, replicatedConstantCompositeInfo.value().first, resultType);
1870  return success();
1871  }
1872 
1873  return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
1874  << constantID
1875  << " must come from a normal constant or a "
1876  "OpConstantCompositeReplicateEXT";
1877 }
1878 
1879 LogicalResult
1880 spirv::Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
1881  if (operands.size() < 2) {
1882  return emitError(
1883  unknownLoc,
1884  "OpSpecConstantComposite must have type <id> and result <id>");
1885  }
1886  if (operands.size() < 3) {
1887  return emitError(unknownLoc,
1888  "OpSpecConstantComposite must have at least 1 parameter");
1889  }
1890 
1891  Type resultType = getType(operands[0]);
1892  if (!resultType) {
1893  return emitError(unknownLoc, "undefined result type from <id> ")
1894  << operands[0];
1895  }
1896 
1897  auto resultID = operands[1];
1898  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1899 
1900  SmallVector<Attribute, 4> elements;
1901  elements.reserve(operands.size() - 2);
1902  for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1903  auto elementInfo = getSpecConstant(operands[i]);
1904  elements.push_back(SymbolRefAttr::get(elementInfo));
1905  }
1906 
1907  auto op = spirv::SpecConstantCompositeOp::create(
1908  opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1909  opBuilder.getArrayAttr(elements));
1910  specConstCompositeMap[resultID] = op;
1911 
1912  return success();
1913 }
1914 
1915 LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
1916  ArrayRef<uint32_t> operands) {
1917  if (operands.size() != 3) {
1918  return emitError(unknownLoc, "OpSpecConstantCompositeReplicateEXT expects "
1919  "3 operands but found ")
1920  << operands.size();
1921  }
1922 
1923  Type resultType = getType(operands[0]);
1924  if (!resultType) {
1925  return emitError(unknownLoc, "undefined result type from <id> ")
1926  << operands[0];
1927  }
1928 
1929  auto compositeType = dyn_cast<CompositeType>(resultType);
1930  if (!compositeType) {
1931  return emitError(unknownLoc,
1932  "result type from <id> is not a composite type")
1933  << operands[0];
1934  }
1935 
1936  uint32_t resultID = operands[1];
1937 
1938  auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1939  spirv::SpecConstantOp constituentSpecConstantOp =
1940  getSpecConstant(operands[2]);
1941  auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1942  opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1943  SymbolRefAttr::get(constituentSpecConstantOp));
1944 
1945  specConstCompositeReplicateMap[resultID] = op;
1946 
1947  return success();
1948 }
1949 
1950 LogicalResult
1951 spirv::Deserializer::processSpecConstantOperation(ArrayRef<uint32_t> operands) {
1952  if (operands.size() < 3)
1953  return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1954  "result <id>, and operand opcode");
1955 
1956  uint32_t resultTypeID = operands[0];
1957 
1958  if (!getType(resultTypeID))
1959  return emitError(unknownLoc, "undefined result type from <id> ")
1960  << resultTypeID;
1961 
1962  uint32_t resultID = operands[1];
1963  spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1964  auto emplaceResult = specConstOperationMap.try_emplace(
1965  resultID,
1966  SpecConstOperationMaterializationInfo{
1967  enclosedOpcode, resultTypeID,
1968  SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1969 
1970  if (!emplaceResult.second)
1971  return emitError(unknownLoc, "value with <id>: ")
1972  << resultID << " is probably defined before.";
1973 
1974  return success();
1975 }
1976 
1977 Value spirv::Deserializer::materializeSpecConstantOperation(
1978  uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1979  ArrayRef<uint32_t> enclosedOpOperands) {
1980 
1981  Type resultType = getType(resultTypeID);
1982 
1983  // Instructions wrapped by OpSpecConstantOp need an ID for their
1984  // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1985  // dialect wrapped op. For that purpose, a new value map is created and "fake"
1986  // ID in that map is assigned to the result of the enclosed instruction. Note
1987  // that there is no need to update this fake ID since we only need to
1988  // reference the created Value for the enclosed op from the spv::YieldOp
1989  // created later in this method (both of which are the only values in their
1990  // region: the SpecConstantOperation's region). If we encounter another
1991  // SpecConstantOperation in the module, we simply re-use the fake ID since the
1992  // previous Value assigned to it isn't visible in the current scope anyway.
1993  DenseMap<uint32_t, Value> newValueMap;
1994  llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1995  constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1996 
1997  SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
1998  enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1999  enclosedOpResultTypeAndOperands.push_back(fakeID);
2000  enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2001  enclosedOpOperands.end());
2002 
2003  // Process enclosed instruction before creating the enclosing
2004  // specConstantOperation (and its region). This way, references to constants,
2005  // global variables, and spec constants will be materialized outside the new
2006  // op's region. For more info, see Deserializer::getValue's implementation.
2007  if (failed(
2008  processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
2009  return Value();
2010 
2011  // Since the enclosed op is emitted in the current block, split it in a
2012  // separate new block.
2013  Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
2014 
2015  auto loc = createFileLineColLoc(opBuilder);
2016  auto specConstOperationOp =
2017  spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2018 
2019  Region &body = specConstOperationOp.getBody();
2020  // Move the new block into SpecConstantOperation's body.
2021  body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
2022  Region::iterator(enclosedBlock));
2023  Block &block = body.back();
2024 
2025  // RAII guard to reset the insertion point to the module's region after
2026  // deserializing the body of the specConstantOperation.
2027  OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
2028  opBuilder.setInsertionPointToEnd(&block);
2029 
2030  spirv::YieldOp::create(opBuilder, loc, block.front().getResult(0));
2031  return specConstOperationOp.getResult();
2032 }
2033 
2034 LogicalResult
2035 spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
2036  if (operands.size() != 2) {
2037  return emitError(unknownLoc,
2038  "OpConstantNull must only have type <id> and result <id>");
2039  }
2040 
2041  Type resultType = getType(operands[0]);
2042  if (!resultType) {
2043  return emitError(unknownLoc, "undefined result type from <id> ")
2044  << operands[0];
2045  }
2046 
2047  auto resultID = operands[1];
2048  Attribute attr;
2049  if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
2050  attr = opBuilder.getZeroAttr(resultType);
2051  } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2052  if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2053  attr = DenseElementsAttr::get(tensorType, element);
2054  }
2055 
2056  if (attr) {
2057  // For normal constants, we just record the attribute (and its type) for
2058  // later materialization at use sites.
2059  constantMap.try_emplace(resultID, attr, resultType);
2060  return success();
2061  }
2062 
2063  return emitError(unknownLoc, "unsupported OpConstantNull type: ")
2064  << resultType;
2065 }
2066 
2067 LogicalResult
2068 spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
2069  if (operands.size() < 3) {
2070  return emitError(unknownLoc)
2071  << "OpGraphConstantARM must have at least 2 operands";
2072  }
2073 
2074  Type resultType = getType(operands[0]);
2075  if (!resultType) {
2076  return emitError(unknownLoc, "undefined result type from <id> ")
2077  << operands[0];
2078  }
2079 
2080  uint32_t resultID = operands[1];
2081 
2082  if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2083  return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
2084  }
2085 
2086  APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
2087  Type i32Ty = opBuilder.getIntegerType(32);
2088  IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2089  graphConstantMap.try_emplace(
2090  resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2091 
2092  return success();
2093 }
2094 
2095 //===----------------------------------------------------------------------===//
2096 // Control flow
2097 //===----------------------------------------------------------------------===//
2098 
2099 Block *spirv::Deserializer::getOrCreateBlock(uint32_t id) {
2100  if (auto *block = getBlock(id)) {
2101  LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
2102  << " @ " << block << "\n");
2103  return block;
2104  }
2105 
2106  // We don't know where this block will be placed finally (in a
2107  // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
2108  // function for now and sort out the proper place later.
2109  auto *block = curFunction->addBlock();
2110  LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
2111  << " @ " << block << "\n");
2112  return blockMap[id] = block;
2113 }
2114 
2115 LogicalResult spirv::Deserializer::processBranch(ArrayRef<uint32_t> operands) {
2116  if (!curBlock) {
2117  return emitError(unknownLoc, "OpBranch must appear inside a block");
2118  }
2119 
2120  if (operands.size() != 1) {
2121  return emitError(unknownLoc, "OpBranch must take exactly one target label");
2122  }
2123 
2124  auto *target = getOrCreateBlock(operands[0]);
2125  auto loc = createFileLineColLoc(opBuilder);
2126  // The preceding instruction for the OpBranch instruction could be an
2127  // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
2128  // the same OpLine information.
2129  spirv::BranchOp::create(opBuilder, loc, target);
2130 
2131  clearDebugLine();
2132  return success();
2133 }
2134 
2135 LogicalResult
2136 spirv::Deserializer::processBranchConditional(ArrayRef<uint32_t> operands) {
2137  if (!curBlock) {
2138  return emitError(unknownLoc,
2139  "OpBranchConditional must appear inside a block");
2140  }
2141 
2142  if (operands.size() != 3 && operands.size() != 5) {
2143  return emitError(unknownLoc,
2144  "OpBranchConditional must have condition, true label, "
2145  "false label, and optionally two branch weights");
2146  }
2147 
2148  auto condition = getValue(operands[0]);
2149  auto *trueBlock = getOrCreateBlock(operands[1]);
2150  auto *falseBlock = getOrCreateBlock(operands[2]);
2151 
2152  std::optional<std::pair<uint32_t, uint32_t>> weights;
2153  if (operands.size() == 5) {
2154  weights = std::make_pair(operands[3], operands[4]);
2155  }
2156  // The preceding instruction for the OpBranchConditional instruction could be
2157  // an OpSelectionMerge instruction, in this case they will have the same
2158  // OpLine information.
2159  auto loc = createFileLineColLoc(opBuilder);
2160  spirv::BranchConditionalOp::create(
2161  opBuilder, loc, condition, trueBlock,
2162  /*trueArguments=*/ArrayRef<Value>(), falseBlock,
2163  /*falseArguments=*/ArrayRef<Value>(), weights);
2164 
2165  clearDebugLine();
2166  return success();
2167 }
2168 
2169 LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
2170  if (!curFunction) {
2171  return emitError(unknownLoc, "OpLabel must appear inside a function");
2172  }
2173 
2174  if (operands.size() != 1) {
2175  return emitError(unknownLoc, "OpLabel should only have result <id>");
2176  }
2177 
2178  auto labelID = operands[0];
2179  // We may have forward declared this block.
2180  auto *block = getOrCreateBlock(labelID);
2181  LLVM_DEBUG(logger.startLine()
2182  << "[block] populating block " << block << "\n");
2183  // If we have seen this block, make sure it was just a forward declaration.
2184  assert(block->empty() && "re-deserialize the same block!");
2185 
2186  opBuilder.setInsertionPointToStart(block);
2187  blockMap[labelID] = curBlock = block;
2188 
2189  return success();
2190 }
2191 
2192 LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
2193  if (!curGraph) {
2194  return emitError(unknownLoc, "a graph block must appear inside a graph");
2195  }
2196 
2197  // We may have forward declared this block.
2198  Block *block = getOrCreateBlock(graphID);
2199  LLVM_DEBUG(logger.startLine()
2200  << "[block] populating block " << block << "\n");
2201  // If we have seen this block, make sure it was just a forward declaration.
2202  assert(block->empty() && "re-deserialize the same block!");
2203 
2204  opBuilder.setInsertionPointToStart(block);
2205  blockMap[graphID] = curBlock = block;
2206 
2207  return success();
2208 }
2209 
2210 LogicalResult
2211 spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
2212  if (!curBlock) {
2213  return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
2214  }
2215 
2216  if (operands.size() < 2) {
2217  return emitError(
2218  unknownLoc,
2219  "OpSelectionMerge must specify merge target and selection control");
2220  }
2221 
2222  auto *mergeBlock = getOrCreateBlock(operands[0]);
2223  auto loc = createFileLineColLoc(opBuilder);
2224  auto selectionControl = operands[1];
2225 
2226  if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2227  .second) {
2228  return emitError(
2229  unknownLoc,
2230  "a block cannot have more than one OpSelectionMerge instruction");
2231  }
2232 
2233  return success();
2234 }
2235 
2236 LogicalResult
2237 spirv::Deserializer::processLoopMerge(ArrayRef<uint32_t> operands) {
2238  if (!curBlock) {
2239  return emitError(unknownLoc, "OpLoopMerge must appear in a block");
2240  }
2241 
2242  if (operands.size() < 3) {
2243  return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
2244  "continue target and loop control");
2245  }
2246 
2247  auto *mergeBlock = getOrCreateBlock(operands[0]);
2248  auto *continueBlock = getOrCreateBlock(operands[1]);
2249  auto loc = createFileLineColLoc(opBuilder);
2250  uint32_t loopControl = operands[2];
2251 
2252  if (!blockMergeInfo
2253  .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2254  .second) {
2255  return emitError(
2256  unknownLoc,
2257  "a block cannot have more than one OpLoopMerge instruction");
2258  }
2259 
2260  return success();
2261 }
2262 
2263 LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
2264  if (!curBlock) {
2265  return emitError(unknownLoc, "OpPhi must appear in a block");
2266  }
2267 
2268  if (operands.size() < 4) {
2269  return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
2270  "and variable-parent pairs");
2271  }
2272 
2273  // Create a block argument for this OpPhi instruction.
2274  Type blockArgType = getType(operands[0]);
2275  BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2276  valueMap[operands[1]] = blockArg;
2277  LLVM_DEBUG(logger.startLine()
2278  << "[phi] created block argument " << blockArg
2279  << " id = " << operands[1] << " of type " << blockArgType << "\n");
2280 
2281  // For each (value, predecessor) pair, insert the value to the predecessor's
2282  // blockPhiInfo entry so later we can fix the block argument there.
2283  for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2284  uint32_t value = operands[i];
2285  Block *predecessor = getOrCreateBlock(operands[i + 1]);
2286  std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2287  blockPhiInfo[predecessorTargetPair].push_back(value);
2288  LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
2289  << " with arg id = " << value << "\n");
2290  }
2291 
2292  return success();
2293 }
2294 
2295 namespace {
2296 /// A class for putting all blocks in a structured selection/loop in a
2297 /// spirv.mlir.selection/spirv.mlir.loop op.
2298 class ControlFlowStructurizer {
2299 public:
2300 #ifndef NDEBUG
2301  ControlFlowStructurizer(Location loc, uint32_t control,
2302  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2303  Block *merge, Block *cont,
2304  llvm::ScopedPrinter &logger)
2305  : location(loc), control(control), blockMergeInfo(mergeInfo),
2306  headerBlock(header), mergeBlock(merge), continueBlock(cont),
2307  logger(logger) {}
2308 #else
2309  ControlFlowStructurizer(Location loc, uint32_t control,
2310  spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2311  Block *merge, Block *cont)
2312  : location(loc), control(control), blockMergeInfo(mergeInfo),
2313  headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2314 #endif
2315 
2316  /// Structurizes the loop at the given `headerBlock`.
2317  ///
2318  /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
2319  /// all blocks in the structured loop into the spirv.mlir.loop's region. All
2320  /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
2321  /// method will also update `mergeInfo` by remapping all blocks inside to the
2322  /// newly cloned ones inside structured control flow op's regions.
2323  LogicalResult structurize();
2324 
2325 private:
2326  /// Creates a new spirv.mlir.selection op at the beginning of the
2327  /// `mergeBlock`.
2328  spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2329 
2330  /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
2331  spirv::LoopOp createLoopOp(uint32_t loopControl);
2332 
2333  /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
2334  void collectBlocksInConstruct();
2335 
2336  Location location;
2337  uint32_t control;
2338 
2339  spirv::BlockMergeInfoMap &blockMergeInfo;
2340 
2341  Block *headerBlock;
2342  Block *mergeBlock;
2343  Block *continueBlock; // nullptr for spirv.mlir.selection
2344 
2345  SetVector<Block *> constructBlocks;
2346 
2347 #ifndef NDEBUG
2348  /// A logger used to emit information during the deserialzation process.
2349  llvm::ScopedPrinter &logger;
2350 #endif
2351 };
2352 } // namespace
2353 
2354 spirv::SelectionOp
2355 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2356  // Create a builder and set the insertion point to the beginning of the
2357  // merge block so that the newly created SelectionOp will be inserted there.
2358  OpBuilder builder(&mergeBlock->front());
2359 
2360  auto control = static_cast<spirv::SelectionControl>(selectionControl);
2361  auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2362  selectionOp.addMergeBlock(builder);
2363 
2364  return selectionOp;
2365 }
2366 
2367 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2368  // Create a builder and set the insertion point to the beginning of the
2369  // merge block so that the newly created LoopOp will be inserted there.
2370  OpBuilder builder(&mergeBlock->front());
2371 
2372  auto control = static_cast<spirv::LoopControl>(loopControl);
2373  auto loopOp = spirv::LoopOp::create(builder, location, control);
2374  loopOp.addEntryAndMergeBlock(builder);
2375 
2376  return loopOp;
2377 }
2378 
2379 void ControlFlowStructurizer::collectBlocksInConstruct() {
2380  assert(constructBlocks.empty() && "expected empty constructBlocks");
2381 
2382  // Put the header block in the work list first.
2383  constructBlocks.insert(headerBlock);
2384 
2385  // For each item in the work list, add its successors excluding the merge
2386  // block.
2387  for (unsigned i = 0; i < constructBlocks.size(); ++i) {
2388  for (auto *successor : constructBlocks[i]->getSuccessors())
2389  if (successor != mergeBlock)
2390  constructBlocks.insert(successor);
2391  }
2392 }
2393 
2394 LogicalResult ControlFlowStructurizer::structurize() {
2395  Operation *op = nullptr;
2396  bool isLoop = continueBlock != nullptr;
2397  if (isLoop) {
2398  if (auto loopOp = createLoopOp(control))
2399  op = loopOp.getOperation();
2400  } else {
2401  if (auto selectionOp = createSelectionOp(control))
2402  op = selectionOp.getOperation();
2403  }
2404  if (!op)
2405  return failure();
2406  Region &body = op->getRegion(0);
2407 
2408  IRMapping mapper;
2409  // All references to the old merge block should be directed to the
2410  // selection/loop merge block in the SelectionOp/LoopOp's region.
2411  mapper.map(mergeBlock, &body.back());
2412 
2413  collectBlocksInConstruct();
2414 
2415  // We've identified all blocks belonging to the selection/loop's region. Now
2416  // need to "move" them into the selection/loop. Instead of really moving the
2417  // blocks, in the following we copy them and remap all values and branches.
2418  // This is because:
2419  // * Inserting a block into a region requires the block not in any region
2420  // before. But selections/loops can nest so we can create selection/loop ops
2421  // in a nested manner, which means some blocks may already be in a
2422  // selection/loop region when to be moved again.
2423  // * It's much trickier to fix up the branches into and out of the loop's
2424  // region: we need to treat not-moved blocks and moved blocks differently:
2425  // Not-moved blocks jumping to the loop header block need to jump to the
2426  // merge point containing the new loop op but not the loop continue block's
2427  // back edge. Moved blocks jumping out of the loop need to jump to the
2428  // merge block inside the loop region but not other not-moved blocks.
2429  // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2430  // logic.
2431 
2432  // Create a corresponding block in the SelectionOp/LoopOp's region for each
2433  // block in this loop construct.
2434  OpBuilder builder(body);
2435  for (auto *block : constructBlocks) {
2436  // Create a block and insert it before the selection/loop merge block in the
2437  // SelectionOp/LoopOp's region.
2438  auto *newBlock = builder.createBlock(&body.back());
2439  mapper.map(block, newBlock);
2440  LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2441  << " from block " << block << "\n");
2442  if (!isFnEntryBlock(block)) {
2443  for (BlockArgument blockArg : block->getArguments()) {
2444  auto newArg =
2445  newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2446  mapper.map(blockArg, newArg);
2447  LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2448  << blockArg << " to " << newArg << "\n");
2449  }
2450  } else {
2451  LLVM_DEBUG(logger.startLine()
2452  << "[cf] block " << block << " is a function entry block\n");
2453  }
2454 
2455  for (auto &op : *block)
2456  newBlock->push_back(op.clone(mapper));
2457  }
2458 
2459  // Go through all ops and remap the operands.
2460  auto remapOperands = [&](Operation *op) {
2461  for (auto &operand : op->getOpOperands())
2462  if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2463  operand.set(mappedOp);
2464  for (auto &succOp : op->getBlockOperands())
2465  if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2466  succOp.set(mappedOp);
2467  };
2468  for (auto &block : body)
2469  block.walk(remapOperands);
2470 
2471  // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2472  // the selection/loop construct into its region. Next we need to fix the
2473  // connections between this new SelectionOp/LoopOp with existing blocks.
2474 
2475  // All existing incoming branches should go to the merge block, where the
2476  // SelectionOp/LoopOp resides right now.
2477  headerBlock->replaceAllUsesWith(mergeBlock);
2478 
2479  LLVM_DEBUG({
2480  logger.startLine() << "[cf] after cloning and fixing references:\n";
2481  headerBlock->getParentOp()->print(logger.getOStream());
2482  logger.startLine() << "\n";
2483  });
2484 
2485  if (isLoop) {
2486  if (!mergeBlock->args_empty()) {
2487  return mergeBlock->getParentOp()->emitError(
2488  "OpPhi in loop merge block unsupported");
2489  }
2490 
2491  // The loop header block may have block arguments. Since now we place the
2492  // loop op inside the old merge block, we need to make sure the old merge
2493  // block has the same block argument list.
2494  for (BlockArgument blockArg : headerBlock->getArguments())
2495  mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2496 
2497  // If the loop header block has block arguments, make sure the spirv.Branch
2498  // op matches.
2499  SmallVector<Value, 4> blockArgs;
2500  if (!headerBlock->args_empty())
2501  blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2502 
2503  // The loop entry block should have a unconditional branch jumping to the
2504  // loop header block.
2505  builder.setInsertionPointToEnd(&body.front());
2506  spirv::BranchOp::create(builder, location, mapper.lookupOrNull(headerBlock),
2507  ArrayRef<Value>(blockArgs));
2508  }
2509 
2510  // Values defined inside the selection region that need to be yielded outside
2511  // the region.
2512  SmallVector<Value> valuesToYield;
2513  // Outside uses of values that were sunk into the selection region. Those uses
2514  // will be replaced with values returned by the SelectionOp.
2515  SmallVector<Value> outsideUses;
2516 
2517  // Move block arguments of the original block (`mergeBlock`) into the merge
2518  // block inside the selection (`body.back()`). Values produced by block
2519  // arguments will be yielded by the selection region. We do not update uses or
2520  // erase original block arguments yet. It will be done later in the code.
2521  //
2522  // Code below is not executed for loops as it would interfere with the logic
2523  // above. Currently block arguments in the merge block are not supported, but
2524  // instead, the code above copies those arguments from the header block into
2525  // the merge block. As such, running the code would yield those copied
2526  // arguments that is most likely not a desired behaviour. This may need to be
2527  // revisited in the future.
2528  if (!isLoop)
2529  for (BlockArgument blockArg : mergeBlock->getArguments()) {
2530  // Create new block arguments in the last block ("merge block") of the
2531  // selection region. We create one argument for each argument in
2532  // `mergeBlock`. This new value will need to be yielded, and the original
2533  // value replaced, so add them to appropriate vectors.
2534  body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2535  valuesToYield.push_back(body.back().getArguments().back());
2536  outsideUses.push_back(blockArg);
2537  }
2538 
2539  // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2540  // cleaned up.
2541  LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2542  // First we need to drop all operands' references inside all blocks. This is
2543  // needed because we can have blocks referencing SSA values from one another.
2544  for (auto *block : constructBlocks)
2545  block->dropAllReferences();
2546 
2547  // All internal uses should be removed from original blocks by now, so
2548  // whatever is left is an outside use and will need to be yielded from
2549  // the newly created selection / loop region.
2550  for (Block *block : constructBlocks) {
2551  for (Operation &op : *block) {
2552  if (!op.use_empty())
2553  for (Value result : op.getResults()) {
2554  valuesToYield.push_back(mapper.lookupOrNull(result));
2555  outsideUses.push_back(result);
2556  }
2557  }
2558  for (BlockArgument &arg : block->getArguments()) {
2559  if (!arg.use_empty()) {
2560  valuesToYield.push_back(mapper.lookupOrNull(arg));
2561  outsideUses.push_back(arg);
2562  }
2563  }
2564  }
2565 
2566  assert(valuesToYield.size() == outsideUses.size());
2567 
2568  // If we need to yield any values from the selection / loop region we will
2569  // take care of it here.
2570  if (!valuesToYield.empty()) {
2571  LLVM_DEBUG(logger.startLine()
2572  << "[cf] yielding values from the selection / loop region\n");
2573 
2574  // Update `mlir.merge` with values to be yield.
2575  auto mergeOps = body.back().getOps<spirv::MergeOp>();
2576  Operation *merge = llvm::getSingleElement(mergeOps);
2577  assert(merge);
2578  merge->setOperands(valuesToYield);
2579 
2580  // MLIR does not allow changing the number of results of an operation, so
2581  // we create a new SelectionOp / LoopOp with required list of results and
2582  // move the region from the initial SelectionOp / LoopOp. The initial
2583  // operation is then removed. Since we move the region to the new op all
2584  // links between blocks and remapping we have previously done should be
2585  // preserved.
2586  builder.setInsertionPoint(&mergeBlock->front());
2587 
2588  Operation *newOp = nullptr;
2589 
2590  if (isLoop)
2591  newOp = spirv::LoopOp::create(builder, location,
2592  TypeRange(ValueRange(outsideUses)),
2593  static_cast<spirv::LoopControl>(control));
2594  else
2595  newOp = spirv::SelectionOp::create(
2596  builder, location, TypeRange(ValueRange(outsideUses)),
2597  static_cast<spirv::SelectionControl>(control));
2598 
2599  newOp->getRegion(0).takeBody(body);
2600 
2601  // Remove initial op and swap the pointer to the newly created one.
2602  op->erase();
2603  op = newOp;
2604 
2605  // Update all outside uses to use results of the SelectionOp / LoopOp and
2606  // remove block arguments from the original merge block.
2607  for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2608  outsideUses[i].replaceAllUsesWith(op->getResult(i));
2609 
2610  // We do not support block arguments in loop merge block. Also running this
2611  // function with loop would break some of the loop specific code above
2612  // dealing with block arguments.
2613  if (!isLoop)
2614  mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2615  }
2616 
2617  // Check that whether some op in the to-be-erased blocks still has uses. Those
2618  // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2619  // region. We cannot handle such cases given that once a value is sinked into
2620  // the SelectionOp/LoopOp's region, there is no escape for it.
2621  for (auto *block : constructBlocks) {
2622  for (Operation &op : *block)
2623  if (!op.use_empty())
2624  return op.emitOpError("failed control flow structurization: value has "
2625  "uses outside of the "
2626  "enclosing selection/loop construct");
2627  for (BlockArgument &arg : block->getArguments())
2628  if (!arg.use_empty())
2629  return emitError(arg.getLoc(), "failed control flow structurization: "
2630  "block argument has uses outside of the "
2631  "enclosing selection/loop construct");
2632  }
2633 
2634  // Then erase all old blocks.
2635  for (auto *block : constructBlocks) {
2636  // We've cloned all blocks belonging to this construct into the structured
2637  // control flow op's region. Among these blocks, some may compose another
2638  // selection/loop. If so, they will be recorded within blockMergeInfo.
2639  // We need to update the pointers there to the newly remapped ones so we can
2640  // continue structurizing them later.
2641  //
2642  // We need to walk each block as constructBlocks do not include blocks
2643  // internal to ops already structured within those blocks. It is not
2644  // fully clear to me why the mergeInfo of blocks (yet to be structured)
2645  // inside already structured selections/loops get invalidated and needs
2646  // updating, however the following example code can cause a crash (depending
2647  // on the structuring order), when the most inner selection is being
2648  // structured after the outer selection and loop have been already
2649  // structured:
2650  //
2651  // spirv.mlir.for {
2652  // // ...
2653  // spirv.mlir.selection {
2654  // // ..
2655  // // A selection region that hasn't been yet structured!
2656  // // ..
2657  // }
2658  // // ...
2659  // }
2660  //
2661  // If the loop gets structured after the outer selection, but before the
2662  // inner selection. Moving the already structured selection inside the loop
2663  // will invalidate the mergeInfo of the region that is not yet structured.
2664  // Just going over constructBlocks will not check and updated header blocks
2665  // inside the already structured selection region. Walking block fixes that.
2666  //
2667  // TODO: If structuring was done in a fixed order starting with inner
2668  // most constructs this most likely not be an issue and the whole code
2669  // section could be removed. However, with the current non-deterministic
2670  // order this is not possible.
2671  //
2672  // TODO: The asserts in the following assumes input SPIR-V blob forms
2673  // correctly nested selection/loop constructs. We should relax this and
2674  // support error cases better.
2675  auto updateMergeInfo = [&](Block *block) -> WalkResult {
2676  auto it = blockMergeInfo.find(block);
2677  if (it != blockMergeInfo.end()) {
2678  // Use the original location for nested selection/loop ops.
2679  Location loc = it->second.loc;
2680 
2681  Block *newHeader = mapper.lookupOrNull(block);
2682  if (!newHeader)
2683  return emitError(loc, "failed control flow structurization: nested "
2684  "loop header block should be remapped!");
2685 
2686  Block *newContinue = it->second.continueBlock;
2687  if (newContinue) {
2688  newContinue = mapper.lookupOrNull(newContinue);
2689  if (!newContinue)
2690  return emitError(loc, "failed control flow structurization: nested "
2691  "loop continue block should be remapped!");
2692  }
2693 
2694  Block *newMerge = it->second.mergeBlock;
2695  if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2696  newMerge = mappedTo;
2697 
2698  // The iterator should be erased before adding a new entry into
2699  // blockMergeInfo to avoid iterator invalidation.
2700  blockMergeInfo.erase(it);
2701  blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2702  newContinue);
2703  }
2704 
2705  return WalkResult::advance();
2706  };
2707 
2708  if (block->walk(updateMergeInfo).wasInterrupted())
2709  return failure();
2710 
2711  // The structured selection/loop's entry block does not have arguments.
2712  // If the function's header block is also part of the structured control
2713  // flow, we cannot just simply erase it because it may contain arguments
2714  // matching the function signature and used by the cloned blocks.
2715  if (isFnEntryBlock(block)) {
2716  LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2717  << " to only contain a spirv.Branch op\n");
2718  // Still keep the function entry block for the potential block arguments,
2719  // but replace all ops inside with a branch to the merge block.
2720  block->clear();
2721  builder.setInsertionPointToEnd(block);
2722  spirv::BranchOp::create(builder, location, mergeBlock);
2723  } else {
2724  LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2725  block->erase();
2726  }
2727  }
2728 
2729  LLVM_DEBUG(logger.startLine()
2730  << "[cf] after structurizing construct with header block "
2731  << headerBlock << ":\n"
2732  << *op << "\n");
2733 
2734  return success();
2735 }
2736 
2737 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2738  LLVM_DEBUG({
2739  logger.startLine()
2740  << "//----- [phi] start wiring up block arguments -----//\n";
2741  logger.indent();
2742  });
2743 
2744  OpBuilder::InsertionGuard guard(opBuilder);
2745 
2746  for (const auto &info : blockPhiInfo) {
2747  Block *block = info.first.first;
2748  Block *target = info.first.second;
2749  const BlockPhiInfo &phiInfo = info.second;
2750  LLVM_DEBUG({
2751  logger.startLine() << "[phi] block " << block << "\n";
2752  logger.startLine() << "[phi] before creating block argument:\n";
2753  block->getParentOp()->print(logger.getOStream());
2754  logger.startLine() << "\n";
2755  });
2756 
2757  // Set insertion point to before this block's terminator early because we
2758  // may materialize ops via getValue() call.
2759  auto *op = block->getTerminator();
2760  opBuilder.setInsertionPoint(op);
2761 
2762  SmallVector<Value, 4> blockArgs;
2763  blockArgs.reserve(phiInfo.size());
2764  for (uint32_t valueId : phiInfo) {
2765  if (Value value = getValue(valueId)) {
2766  blockArgs.push_back(value);
2767  LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2768  << " id = " << valueId << "\n");
2769  } else {
2770  return emitError(unknownLoc, "OpPhi references undefined value!");
2771  }
2772  }
2773 
2774  if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2775  // Replace the previous branch op with a new one with block arguments.
2776  spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2777  branchOp.getTarget(), blockArgs);
2778  branchOp.erase();
2779  } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2780  assert((branchCondOp.getTrueBlock() == target ||
2781  branchCondOp.getFalseBlock() == target) &&
2782  "expected target to be either the true or false target");
2783  if (target == branchCondOp.getTrueTarget())
2784  spirv::BranchConditionalOp::create(
2785  opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2786  blockArgs, branchCondOp.getFalseBlockArguments(),
2787  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2788  branchCondOp.getFalseTarget());
2789  else
2790  spirv::BranchConditionalOp::create(
2791  opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2792  branchCondOp.getTrueBlockArguments(), blockArgs,
2793  branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2794  branchCondOp.getFalseBlock());
2795 
2796  branchCondOp.erase();
2797  } else {
2798  return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2799  }
2800 
2801  LLVM_DEBUG({
2802  logger.startLine() << "[phi] after creating block argument:\n";
2803  block->getParentOp()->print(logger.getOStream());
2804  logger.startLine() << "\n";
2805  });
2806  }
2807  blockPhiInfo.clear();
2808 
2809  LLVM_DEBUG({
2810  logger.unindent();
2811  logger.startLine()
2812  << "//--- [phi] completed wiring up block arguments ---//\n";
2813  });
2814  return success();
2815 }
2816 
2817 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2818  // Create a copy, so we can modify keys in the original.
2819  BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2820  for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2821  it != e; ++it) {
2822  auto &[block, mergeInfo] = *it;
2823 
2824  // Skip processing loop regions. For loop regions continueBlock is non-null.
2825  if (mergeInfo.continueBlock)
2826  continue;
2827 
2828  if (!block->mightHaveTerminator())
2829  continue;
2830 
2831  Operation *terminator = block->getTerminator();
2832  assert(terminator);
2833 
2834  if (!isa<spirv::BranchConditionalOp>(terminator))
2835  continue;
2836 
2837  // Check if the current header block is a merge block of another construct.
2838  bool splitHeaderMergeBlock = false;
2839  for (const auto &[_, mergeInfo] : blockMergeInfo) {
2840  if (mergeInfo.mergeBlock == block)
2841  splitHeaderMergeBlock = true;
2842  }
2843 
2844  // Do not split a block that only contains a conditional branch, unless it
2845  // is also a merge block of another construct - in that case we want to
2846  // split the block. We do not want two constructs to share header / merge
2847  // block.
2848  if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2849  Block *newBlock = block->splitBlock(terminator);
2850  OpBuilder builder(block, block->end());
2851  spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2852 
2853  // After splitting we need to update the map to use the new block as a
2854  // header.
2855  blockMergeInfo.erase(block);
2856  blockMergeInfo.try_emplace(newBlock, mergeInfo);
2857  }
2858  }
2859 
2860  return success();
2861 }
2862 
2863 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2864  if (!options.enableControlFlowStructurization) {
2865  LLVM_DEBUG(
2866  {
2867  logger.startLine()
2868  << "//----- [cf] skip structurizing control flow -----//\n";
2869  logger.indent();
2870  });
2871  return success();
2872  }
2873 
2874  LLVM_DEBUG({
2875  logger.startLine()
2876  << "//----- [cf] start structurizing control flow -----//\n";
2877  logger.indent();
2878  });
2879 
2880  LLVM_DEBUG({
2881  logger.startLine() << "[cf] split conditional blocks\n";
2882  logger.startLine() << "\n";
2883  });
2884 
2885  if (failed(splitConditionalBlocks())) {
2886  return failure();
2887  }
2888 
2889  // TODO: This loop is non-deterministic. Iteration order may vary between runs
2890  // for the same shader as the key to the map is a pointer. See:
2891  // https://github.com/llvm/llvm-project/issues/128547
2892  while (!blockMergeInfo.empty()) {
2893  Block *headerBlock = blockMergeInfo.begin()->first;
2894  BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2895 
2896  LLVM_DEBUG({
2897  logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2898  headerBlock->print(logger.getOStream());
2899  logger.startLine() << "\n";
2900  });
2901 
2902  auto *mergeBlock = mergeInfo.mergeBlock;
2903  assert(mergeBlock && "merge block cannot be nullptr");
2904  if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2905  return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2906  LLVM_DEBUG({
2907  logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2908  mergeBlock->print(logger.getOStream());
2909  logger.startLine() << "\n";
2910  });
2911 
2912  auto *continueBlock = mergeInfo.continueBlock;
2913  LLVM_DEBUG(if (continueBlock) {
2914  logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2915  continueBlock->print(logger.getOStream());
2916  logger.startLine() << "\n";
2917  });
2918  // Erase this case before calling into structurizer, who will update
2919  // blockMergeInfo.
2920  blockMergeInfo.erase(blockMergeInfo.begin());
2921  ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2922  blockMergeInfo, headerBlock,
2923  mergeBlock, continueBlock
2924 #ifndef NDEBUG
2925  ,
2926  logger
2927 #endif
2928  );
2929  if (failed(structurizer.structurize()))
2930  return failure();
2931  }
2932 
2933  LLVM_DEBUG({
2934  logger.unindent();
2935  logger.startLine()
2936  << "//--- [cf] completed structurizing control flow ---//\n";
2937  });
2938  return success();
2939 }
2940 
2941 //===----------------------------------------------------------------------===//
2942 // Debug
2943 //===----------------------------------------------------------------------===//
2944 
2945 Location spirv::Deserializer::createFileLineColLoc(OpBuilder opBuilder) {
2946  if (!debugLine)
2947  return unknownLoc;
2948 
2949  auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2950  if (fileName.empty())
2951  fileName = "<unknown>";
2952  return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
2953  debugLine->column);
2954 }
2955 
2956 LogicalResult
2957 spirv::Deserializer::processDebugLine(ArrayRef<uint32_t> operands) {
2958  // According to SPIR-V spec:
2959  // "This location information applies to the instructions physically
2960  // following this instruction, up to the first occurrence of any of the
2961  // following: the next end of block, the next OpLine instruction, or the next
2962  // OpNoLine instruction."
2963  if (operands.size() != 3)
2964  return emitError(unknownLoc, "OpLine must have 3 operands");
2965  debugLine = DebugLine{operands[0], operands[1], operands[2]};
2966  return success();
2967 }
2968 
2969 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2970 
2971 LogicalResult
2972 spirv::Deserializer::processDebugString(ArrayRef<uint32_t> operands) {
2973  if (operands.size() < 2)
2974  return emitError(unknownLoc, "OpString needs at least 2 operands");
2975 
2976  if (!debugInfoMap.lookup(operands[0]).empty())
2977  return emitError(unknownLoc,
2978  "duplicate debug string found for result <id> ")
2979  << operands[0];
2980 
2981  unsigned wordIndex = 1;
2982  StringRef debugString = decodeStringLiteral(operands, wordIndex);
2983  if (wordIndex != operands.size())
2984  return emitError(unknownLoc,
2985  "unexpected trailing words in OpString instruction");
2986 
2987  debugInfoMap[operands[0]] = debugString;
2988  return success();
2989 }
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
int64_t rows
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
void erase()
Unlink this Block from its parent region and delete it.
Definition: Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition: Block.cpp:318
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
void print(raw_ostream &os)
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition: Block.cpp:250
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
void push_back(Operation *op)
Definition: Block.h:149
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:261
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition: Location.cpp:157
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
Definition: IRMapping.h:26
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition: IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition: IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:852
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:718
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition: Operation.cpp:66
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
MutableArrayRef< BlockOperand > getBlockOperands()
Definition: Operation.h:695
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:383
result_range getResults()
Definition: Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:538
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & back()
Definition: Region.h:64
BlockListType & getBlocks()
Definition: Region.h:45
Location getLoc()
Return a location for this region.
Definition: Region.cpp:31
BlockListType::iterator iterator
Definition: Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition: Region.h:241
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: WalkResult.h:29
static WalkResult advance()
Definition: WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:157
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
Definition: SPIRVTypes.cpp:313
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition: SPIRVTypes.h:155
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
Definition: SPIRVTypes.cpp:495
static RuntimeArrayType get(Type elementType)
Definition: SPIRVTypes.cpp:553
static SampledImageType get(Type imageType)
Definition: SPIRVTypes.cpp:789
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition: CallGraph.h:229
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
Definition: Deserializer.h:61
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.