MLIR  22.0.0git
Serializer.cpp
Go to the documentation of this file.
1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
27 #include <cstdint>
28 #include <optional>
29 
30 #define DEBUG_TYPE "spirv-serialization"
31 
32 using namespace mlir;
33 
34 /// Returns the merge block if the given `op` is a structured control flow op.
35 /// Otherwise returns nullptr.
37  if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38  return selectionOp.getMergeBlock();
39  if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40  return loopOp.getMergeBlock();
41  return nullptr;
42 }
43 
44 /// Given a predecessor `block` for a block with arguments, returns the block
45 /// that should be used as the parent block for SPIR-V OpPhi instructions
46 /// corresponding to the block arguments.
47 static Block *getPhiIncomingBlock(Block *block) {
48  // If the predecessor block in question is the entry block for a
49  // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50  if (block->isEntryBlock()) {
51  if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52  // Then the incoming parent block for OpPhi should be the merge block of
53  // the structured control flow op before this loop.
54  Operation *op = loopOp.getOperation();
55  while ((op = op->getPrevNode()) != nullptr)
56  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57  return incomingBlock;
58  // Or the enclosing block itself if no structured control flow ops
59  // exists before this loop.
60  return loopOp->getBlock();
61  }
62  }
63 
64  // Otherwise, we jump from the given predecessor block. Try to see if there is
65  // a structured control flow op inside it.
66  for (Operation &op : llvm::reverse(block->getOperations())) {
67  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68  return incomingBlock;
69  }
70  return block;
71 }
72 
73 static bool isZeroValue(Attribute attr) {
74  if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75  return floatAttr.getValue().isZero();
76  }
77  if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78  return !boolAttr.getValue();
79  }
80  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81  return intAttr.getValue().isZero();
82  }
83  if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
84  return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
85  }
86  if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
87  return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
88  }
89  return false;
90 }
91 
92 namespace mlir {
93 namespace spirv {
94 
95 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
96 /// the given `binary` vector.
97 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
98  ArrayRef<uint32_t> operands) {
99  uint32_t wordCount = 1 + operands.size();
100  binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
101  binary.append(operands.begin(), operands.end());
102 }
103 
104 Serializer::Serializer(spirv::ModuleOp module,
106  : module(module), mlirBuilder(module.getContext()), options(options) {}
107 
108 LogicalResult Serializer::serialize() {
109  LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
110 
111  if (failed(module.verifyInvariants()))
112  return failure();
113 
114  // TODO: handle the other sections
115  processCapability();
116  if (failed(processExtension())) {
117  return failure();
118  }
119  processMemoryModel();
120  processDebugInfo();
121 
122  // Iterate over the module body to serialize it. Assumptions are that there is
123  // only one basic block in the moduleOp
124  for (auto &op : *module.getBody()) {
125  if (failed(processOperation(&op))) {
126  return failure();
127  }
128  }
129 
130  LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
131  return success();
132 }
133 
135  auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
136  extensions.size() + extendedSets.size() +
137  memoryModel.size() + entryPoints.size() +
138  executionModes.size() + decorations.size() +
139  typesGlobalValues.size() + functions.size();
140 
141  binary.clear();
142  binary.reserve(moduleSize);
143 
144  spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
145  nextID);
146  binary.append(capabilities.begin(), capabilities.end());
147  binary.append(extensions.begin(), extensions.end());
148  binary.append(extendedSets.begin(), extendedSets.end());
149  binary.append(memoryModel.begin(), memoryModel.end());
150  binary.append(entryPoints.begin(), entryPoints.end());
151  binary.append(executionModes.begin(), executionModes.end());
152  binary.append(debug.begin(), debug.end());
153  binary.append(names.begin(), names.end());
154  binary.append(decorations.begin(), decorations.end());
155  binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
156  binary.append(functions.begin(), functions.end());
157 }
158 
159 #ifndef NDEBUG
160 void Serializer::printValueIDMap(raw_ostream &os) {
161  os << "\n= Value <id> Map =\n\n";
162  for (auto valueIDPair : valueIDMap) {
163  Value val = valueIDPair.first;
164  os << " " << val << " "
165  << "id = " << valueIDPair.second << ' ';
166  if (auto *op = val.getDefiningOp()) {
167  os << "from op '" << op->getName() << "'";
168  } else if (auto arg = dyn_cast<BlockArgument>(val)) {
169  Block *block = arg.getOwner();
170  os << "from argument of block " << block << ' ';
171  os << " in op '" << block->getParentOp()->getName() << "'";
172  }
173  os << '\n';
174  }
175 }
176 #endif
177 
178 //===----------------------------------------------------------------------===//
179 // Module structure
180 //===----------------------------------------------------------------------===//
181 
182 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
183  auto funcID = funcIDMap.lookup(fnName);
184  if (!funcID) {
185  funcID = getNextID();
186  funcIDMap[fnName] = funcID;
187  }
188  return funcID;
189 }
190 
191 void Serializer::processCapability() {
192  for (auto cap : module.getVceTriple()->getCapabilities())
193  encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
194  {static_cast<uint32_t>(cap)});
195 }
196 
197 void Serializer::processDebugInfo() {
198  if (!options.emitDebugInfo)
199  return;
200  auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
201  auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
202  fileID = getNextID();
203  SmallVector<uint32_t, 16> operands;
204  operands.push_back(fileID);
205  spirv::encodeStringLiteralInto(operands, fileName);
206  encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
207  // TODO: Encode more debug instructions.
208 }
209 
210 LogicalResult Serializer::processExtension() {
212  llvm::SmallSet<Extension, 4> deducedExts(
213  llvm::from_range, module.getVceTriple()->getExtensions());
214  auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
215  if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
216  TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module);
217  if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
218  return module.emitError(
219  "SPV_KHR_non_semantic_info extension not available");
220  deducedExts.insert(nonSemanticInfoExt);
221  }
222  for (spirv::Extension ext : deducedExts) {
223  extName.clear();
224  spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
225  encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
226  }
227  return success();
228 }
229 
230 void Serializer::processMemoryModel() {
231  StringAttr memoryModelName = module.getMemoryModelAttrName();
232  auto mm = static_cast<uint32_t>(
233  module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
234  .getValue());
235 
236  StringAttr addressingModelName = module.getAddressingModelAttrName();
237  auto am = static_cast<uint32_t>(
238  module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
239  .getValue());
240 
241  encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
242 }
243 
244 static std::string getDecorationName(StringRef attrName) {
245  // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
246  // expected FPFastMathMode.
247  if (attrName == "fp_fast_math_mode")
248  return "FPFastMathMode";
249  // similar here
250  if (attrName == "fp_rounding_mode")
251  return "FPRoundingMode";
252  // convertToCamelFromSnakeCase will not capitalize "INTEL".
253  if (attrName == "cache_control_load_intel")
254  return "CacheControlLoadINTEL";
255  if (attrName == "cache_control_store_intel")
256  return "CacheControlStoreINTEL";
257 
258  return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
259 }
260 
261 template <typename AttrTy, typename EmitF>
262 LogicalResult processDecorationList(Location loc, Decoration decoration,
263  Attribute attrList, StringRef attrName,
264  EmitF emitter) {
265  auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
266  if (!arrayAttr) {
267  return emitError(loc, "expecting array attribute of ")
268  << attrName << " for " << stringifyDecoration(decoration);
269  }
270  if (arrayAttr.empty()) {
271  return emitError(loc, "expecting non-empty array attribute of ")
272  << attrName << " for " << stringifyDecoration(decoration);
273  }
274  for (Attribute attr : arrayAttr.getValue()) {
275  auto cacheControlAttr = dyn_cast<AttrTy>(attr);
276  if (!cacheControlAttr) {
277  return emitError(loc, "expecting array attribute of ")
278  << attrName << " for " << stringifyDecoration(decoration);
279  }
280  // This named attribute encodes several decorations. Emit one per
281  // element in the array.
282  if (failed(emitter(cacheControlAttr)))
283  return failure();
284  }
285  return success();
286 }
287 
288 LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
289  Decoration decoration,
290  Attribute attr) {
292  switch (decoration) {
293  case spirv::Decoration::LinkageAttributes: {
294  // Get the value of the Linkage Attributes
295  // e.g., LinkageAttributes=["linkageName", linkageType].
296  auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
297  auto linkageName = linkageAttr.getLinkageName();
298  auto linkageType = linkageAttr.getLinkageType().getValue();
299  // Encode the Linkage Name (string literal to uint32_t).
300  spirv::encodeStringLiteralInto(args, linkageName);
301  // Encode LinkageType & Add the Linkagetype to the args.
302  args.push_back(static_cast<uint32_t>(linkageType));
303  break;
304  }
305  case spirv::Decoration::FPFastMathMode:
306  if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
307  args.push_back(static_cast<uint32_t>(intAttr.getValue()));
308  break;
309  }
310  return emitError(loc, "expected FPFastMathModeAttr attribute for ")
311  << stringifyDecoration(decoration);
312  case spirv::Decoration::FPRoundingMode:
313  if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
314  args.push_back(static_cast<uint32_t>(intAttr.getValue()));
315  break;
316  }
317  return emitError(loc, "expected FPRoundingModeAttr attribute for ")
318  << stringifyDecoration(decoration);
319  case spirv::Decoration::Binding:
320  case spirv::Decoration::DescriptorSet:
321  case spirv::Decoration::Location:
322  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
323  args.push_back(intAttr.getValue().getZExtValue());
324  break;
325  }
326  return emitError(loc, "expected integer attribute for ")
327  << stringifyDecoration(decoration);
328  case spirv::Decoration::BuiltIn:
329  if (auto strAttr = dyn_cast<StringAttr>(attr)) {
330  auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
331  if (enumVal) {
332  args.push_back(static_cast<uint32_t>(*enumVal));
333  break;
334  }
335  return emitError(loc, "invalid ")
336  << stringifyDecoration(decoration) << " decoration attribute "
337  << strAttr.getValue();
338  }
339  return emitError(loc, "expected string attribute for ")
340  << stringifyDecoration(decoration);
341  case spirv::Decoration::Aliased:
342  case spirv::Decoration::AliasedPointer:
343  case spirv::Decoration::Flat:
344  case spirv::Decoration::NonReadable:
345  case spirv::Decoration::NonWritable:
346  case spirv::Decoration::NoPerspective:
347  case spirv::Decoration::NoSignedWrap:
348  case spirv::Decoration::NoUnsignedWrap:
349  case spirv::Decoration::RelaxedPrecision:
350  case spirv::Decoration::Restrict:
351  case spirv::Decoration::RestrictPointer:
352  case spirv::Decoration::NoContraction:
353  case spirv::Decoration::Constant:
354  case spirv::Decoration::Block:
355  case spirv::Decoration::Invariant:
356  case spirv::Decoration::Patch:
357  // For unit attributes and decoration attributes, the args list
358  // has no values so we do nothing.
359  if (isa<UnitAttr, DecorationAttr>(attr))
360  break;
361  return emitError(loc,
362  "expected unit attribute or decoration attribute for ")
363  << stringifyDecoration(decoration);
364  case spirv::Decoration::CacheControlLoadINTEL:
365  return processDecorationList<CacheControlLoadINTELAttr>(
366  loc, decoration, attr, "CacheControlLoadINTEL",
367  [&](CacheControlLoadINTELAttr attr) {
368  unsigned cacheLevel = attr.getCacheLevel();
369  LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
370  return emitDecoration(
371  resultID, decoration,
372  {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
373  });
374  case spirv::Decoration::CacheControlStoreINTEL:
375  return processDecorationList<CacheControlStoreINTELAttr>(
376  loc, decoration, attr, "CacheControlStoreINTEL",
377  [&](CacheControlStoreINTELAttr attr) {
378  unsigned cacheLevel = attr.getCacheLevel();
379  StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
380  return emitDecoration(
381  resultID, decoration,
382  {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
383  });
384  default:
385  return emitError(loc, "unhandled decoration ")
386  << stringifyDecoration(decoration);
387  }
388  return emitDecoration(resultID, decoration, args);
389 }
390 
391 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
392  NamedAttribute attr) {
393  StringRef attrName = attr.getName().strref();
394  std::string decorationName = getDecorationName(attrName);
395  std::optional<Decoration> decoration =
396  spirv::symbolizeDecoration(decorationName);
397  if (!decoration) {
398  return emitError(
399  loc, "non-argument attributes expected to have snake-case-ified "
400  "decoration name, unhandled attribute with name : ")
401  << attrName;
402  }
403  return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
404 }
405 
406 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
407  assert(!name.empty() && "unexpected empty string for OpName");
408  if (!options.emitSymbolName)
409  return success();
410 
411  SmallVector<uint32_t, 4> nameOperands;
412  nameOperands.push_back(resultID);
413  spirv::encodeStringLiteralInto(nameOperands, name);
414  encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
415  return success();
416 }
417 
418 template <>
419 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
420  Location loc, spirv::ArrayType type, uint32_t resultID) {
421  if (unsigned stride = type.getArrayStride()) {
422  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
423  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
424  }
425  return success();
426 }
427 
428 template <>
429 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
430  Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
431  if (unsigned stride = type.getArrayStride()) {
432  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
433  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
434  }
435  return success();
436 }
437 
438 LogicalResult Serializer::processMemberDecoration(
439  uint32_t structID,
440  const spirv::StructType::MemberDecorationInfo &memberDecoration) {
442  {structID, memberDecoration.memberIndex,
443  static_cast<uint32_t>(memberDecoration.decoration)});
444  if (memberDecoration.hasValue()) {
445  args.push_back(
446  cast<IntegerAttr>(memberDecoration.decorationValue).getInt());
447  }
448  encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
449  return success();
450 }
451 
452 //===----------------------------------------------------------------------===//
453 // Type
454 //===----------------------------------------------------------------------===//
455 
456 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
457 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
458 // PushConstant Storage Classes must be explicitly laid out."
459 bool Serializer::isInterfaceStructPtrType(Type type) const {
460  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
461  switch (ptrType.getStorageClass()) {
462  case spirv::StorageClass::PhysicalStorageBuffer:
463  case spirv::StorageClass::PushConstant:
464  case spirv::StorageClass::StorageBuffer:
465  case spirv::StorageClass::Uniform:
466  return isa<spirv::StructType>(ptrType.getPointeeType());
467  default:
468  break;
469  }
470  }
471  return false;
472 }
473 
474 LogicalResult Serializer::processType(Location loc, Type type,
475  uint32_t &typeID) {
476  // Maintains a set of names for nested identified struct types. This is used
477  // to properly serialize recursive references.
478  SetVector<StringRef> serializationCtx;
479  return processTypeImpl(loc, type, typeID, serializationCtx);
480 }
481 
482 LogicalResult
483 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
484  SetVector<StringRef> &serializationCtx) {
485 
486  // Map unsigned integer types to singless integer types.
487  // This is needed otherwise the generated spirv assembly will contain
488  // twice a type declaration (like OpTypeInt 32 0) which is no permitted and
489  // such module fails validation. Indeed at MLIR level the two types are
490  // different and lookup in the cache below misses.
491  // Note: This conversion needs to happen here before the type is looked up in
492  // the cache.
493  if (type.isUnsignedInteger()) {
494  type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
495  IntegerType::SignednessSemantics::Signless);
496  }
497 
498  typeID = getTypeID(type);
499  if (typeID)
500  return success();
501 
502  typeID = getNextID();
503  SmallVector<uint32_t, 4> operands;
504 
505  operands.push_back(typeID);
506  auto typeEnum = spirv::Opcode::OpTypeVoid;
507  bool deferSerialization = false;
508 
509  if ((isa<FunctionType>(type) &&
510  succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
511  operands))) ||
512  succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
513  deferSerialization, serializationCtx))) {
514  if (deferSerialization)
515  return success();
516 
517  typeIDMap[type] = typeID;
518 
519  encodeInstructionInto(typesGlobalValues, typeEnum, operands);
520 
521  if (recursiveStructInfos.count(type) != 0) {
522  // This recursive struct type is emitted already, now the OpTypePointer
523  // instructions referring to recursive references are emitted as well.
524  for (auto &ptrInfo : recursiveStructInfos[type]) {
525  // TODO: This might not work if more than 1 recursive reference is
526  // present in the struct.
527  SmallVector<uint32_t, 4> ptrOperands;
528  ptrOperands.push_back(ptrInfo.pointerTypeID);
529  ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
530  ptrOperands.push_back(typeIDMap[type]);
531 
532  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
533  ptrOperands);
534  }
535 
536  recursiveStructInfos[type].clear();
537  }
538 
539  return success();
540  }
541 
542  return failure();
543 }
544 
545 LogicalResult Serializer::prepareBasicType(
546  Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
547  SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
548  SetVector<StringRef> &serializationCtx) {
549  deferSerialization = false;
550 
551  if (isVoidType(type)) {
552  typeEnum = spirv::Opcode::OpTypeVoid;
553  return success();
554  }
555 
556  if (auto intType = dyn_cast<IntegerType>(type)) {
557  if (intType.getWidth() == 1) {
558  typeEnum = spirv::Opcode::OpTypeBool;
559  return success();
560  }
561 
562  typeEnum = spirv::Opcode::OpTypeInt;
563  operands.push_back(intType.getWidth());
564  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
565  // to preserve or validate.
566  // 0 indicates unsigned, or no signedness semantics
567  // 1 indicates signed semantics."
568  operands.push_back(intType.isSigned() ? 1 : 0);
569  return success();
570  }
571 
572  if (auto floatType = dyn_cast<FloatType>(type)) {
573  typeEnum = spirv::Opcode::OpTypeFloat;
574  operands.push_back(floatType.getWidth());
575  if (floatType.isBF16()) {
576  operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
577  }
578  return success();
579  }
580 
581  if (auto vectorType = dyn_cast<VectorType>(type)) {
582  uint32_t elementTypeID = 0;
583  if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
584  serializationCtx))) {
585  return failure();
586  }
587  typeEnum = spirv::Opcode::OpTypeVector;
588  operands.push_back(elementTypeID);
589  operands.push_back(vectorType.getNumElements());
590  return success();
591  }
592 
593  if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
594  typeEnum = spirv::Opcode::OpTypeImage;
595  uint32_t sampledTypeID = 0;
596  if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
597  return failure();
598 
599  llvm::append_values(operands, sampledTypeID,
600  static_cast<uint32_t>(imageType.getDim()),
601  static_cast<uint32_t>(imageType.getDepthInfo()),
602  static_cast<uint32_t>(imageType.getArrayedInfo()),
603  static_cast<uint32_t>(imageType.getSamplingInfo()),
604  static_cast<uint32_t>(imageType.getSamplerUseInfo()),
605  static_cast<uint32_t>(imageType.getImageFormat()));
606  return success();
607  }
608 
609  if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
610  typeEnum = spirv::Opcode::OpTypeArray;
611  uint32_t elementTypeID = 0;
612  if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
613  serializationCtx))) {
614  return failure();
615  }
616  operands.push_back(elementTypeID);
617  if (auto elementCountID = prepareConstantInt(
618  loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
619  operands.push_back(elementCountID);
620  }
621  return processTypeDecoration(loc, arrayType, resultID);
622  }
623 
624  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
625  uint32_t pointeeTypeID = 0;
626  spirv::StructType pointeeStruct =
627  dyn_cast<spirv::StructType>(ptrType.getPointeeType());
628 
629  if (pointeeStruct && pointeeStruct.isIdentified() &&
630  serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
631  // A recursive reference to an enclosing struct is found.
632  //
633  // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
634  // class as operands.
635  SmallVector<uint32_t, 2> forwardPtrOperands;
636  forwardPtrOperands.push_back(resultID);
637  forwardPtrOperands.push_back(
638  static_cast<uint32_t>(ptrType.getStorageClass()));
639 
640  encodeInstructionInto(typesGlobalValues,
641  spirv::Opcode::OpTypeForwardPointer,
642  forwardPtrOperands);
643 
644  // 2. Find the pointee (enclosing) struct.
645  auto structType = spirv::StructType::getIdentified(
646  module.getContext(), pointeeStruct.getIdentifier());
647 
648  if (!structType)
649  return failure();
650 
651  // 3. Mark the OpTypePointer that is supposed to be emitted by this call
652  // as deferred.
653  deferSerialization = true;
654 
655  // 4. Record the info needed to emit the deferred OpTypePointer
656  // instruction when the enclosing struct is completely serialized.
657  recursiveStructInfos[structType].push_back(
658  {resultID, ptrType.getStorageClass()});
659  } else {
660  if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
661  serializationCtx)))
662  return failure();
663  }
664 
665  typeEnum = spirv::Opcode::OpTypePointer;
666  operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
667  operands.push_back(pointeeTypeID);
668 
669  // TODO: Now struct decorations are supported this code may not be
670  // necessary. However, it is left to support backwards compatibility.
671  // Ideally, Block decorations should be inserted when converting to SPIR-V.
672  if (isInterfaceStructPtrType(ptrType)) {
673  auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
674  if (!structType.hasDecoration(spirv::Decoration::Block))
675  if (failed(emitDecoration(getTypeID(pointeeStruct),
676  spirv::Decoration::Block)))
677  return emitError(loc, "cannot decorate ")
678  << pointeeStruct << " with Block decoration";
679  }
680 
681  return success();
682  }
683 
684  if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
685  uint32_t elementTypeID = 0;
686  if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
687  elementTypeID, serializationCtx))) {
688  return failure();
689  }
690  typeEnum = spirv::Opcode::OpTypeRuntimeArray;
691  operands.push_back(elementTypeID);
692  return processTypeDecoration(loc, runtimeArrayType, resultID);
693  }
694 
695  if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
696  typeEnum = spirv::Opcode::OpTypeSampledImage;
697  uint32_t imageTypeID = 0;
698  if (failed(
699  processType(loc, sampledImageType.getImageType(), imageTypeID))) {
700  return failure();
701  }
702  operands.push_back(imageTypeID);
703  return success();
704  }
705 
706  if (auto structType = dyn_cast<spirv::StructType>(type)) {
707  if (structType.isIdentified()) {
708  if (failed(processName(resultID, structType.getIdentifier())))
709  return failure();
710  serializationCtx.insert(structType.getIdentifier());
711  }
712 
713  bool hasOffset = structType.hasOffset();
714  for (auto elementIndex :
715  llvm::seq<uint32_t>(0, structType.getNumElements())) {
716  uint32_t elementTypeID = 0;
717  if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
718  elementTypeID, serializationCtx))) {
719  return failure();
720  }
721  operands.push_back(elementTypeID);
722  if (hasOffset) {
723  auto intType = IntegerType::get(structType.getContext(), 32);
724  // Decorate each struct member with an offset
726  elementIndex, spirv::Decoration::Offset,
727  IntegerAttr::get(intType,
728  structType.getMemberOffset(elementIndex))};
729  if (failed(processMemberDecoration(resultID, offsetDecoration))) {
730  return emitError(loc, "cannot decorate ")
731  << elementIndex << "-th member of " << structType
732  << " with its offset";
733  }
734  }
735  }
737  structType.getMemberDecorations(memberDecorations);
738 
739  for (auto &memberDecoration : memberDecorations) {
740  if (failed(processMemberDecoration(resultID, memberDecoration))) {
741  return emitError(loc, "cannot decorate ")
742  << static_cast<uint32_t>(memberDecoration.memberIndex)
743  << "-th member of " << structType << " with "
744  << stringifyDecoration(memberDecoration.decoration);
745  }
746  }
747 
749  structType.getStructDecorations(structDecorations);
750 
751  for (spirv::StructType::StructDecorationInfo &structDecoration :
752  structDecorations) {
753  if (failed(processDecorationAttr(loc, resultID,
754  structDecoration.decoration,
755  structDecoration.decorationValue))) {
756  return emitError(loc, "cannot decorate struct ")
757  << structType << " with "
758  << stringifyDecoration(structDecoration.decoration);
759  }
760  }
761 
762  typeEnum = spirv::Opcode::OpTypeStruct;
763 
764  if (structType.isIdentified())
765  serializationCtx.remove(structType.getIdentifier());
766 
767  return success();
768  }
769 
770  if (auto cooperativeMatrixType =
771  dyn_cast<spirv::CooperativeMatrixType>(type)) {
772  uint32_t elementTypeID = 0;
773  if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
774  elementTypeID, serializationCtx))) {
775  return failure();
776  }
777  typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
778  auto getConstantOp = [&](uint32_t id) {
779  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
780  return prepareConstantInt(loc, attr);
781  };
782  llvm::append_values(
783  operands, elementTypeID,
784  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
785  getConstantOp(cooperativeMatrixType.getRows()),
786  getConstantOp(cooperativeMatrixType.getColumns()),
787  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
788  return success();
789  }
790 
791  if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
792  uint32_t elementTypeID = 0;
793  if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
794  serializationCtx))) {
795  return failure();
796  }
797  typeEnum = spirv::Opcode::OpTypeMatrix;
798  llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
799  return success();
800  }
801 
802  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
803  uint32_t elementTypeID = 0;
804  uint32_t rank = 0;
805  uint32_t shapeID = 0;
806  uint32_t rankID = 0;
807  if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
808  elementTypeID, serializationCtx))) {
809  return failure();
810  }
811  if (tensorArmType.hasRank()) {
812  ArrayRef<int64_t> dims = tensorArmType.getShape();
813  rank = dims.size();
814  rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
815  if (rankID == 0) {
816  return failure();
817  }
818 
819  bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
820  if (rank > 0 && shaped) {
821  auto I32Type = IntegerType::get(type.getContext(), 32);
822  auto shapeType = ArrayType::get(I32Type, rank);
823  if (rank == 1) {
824  SmallVector<uint64_t, 1> index(rank);
825  shapeID = prepareDenseElementsConstant(
826  loc, shapeType,
827  mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
828  index);
829  } else {
830  shapeID = prepareArrayConstant(
831  loc, shapeType,
832  mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
833  }
834  if (shapeID == 0) {
835  return failure();
836  }
837  }
838  }
839  typeEnum = spirv::Opcode::OpTypeTensorARM;
840  operands.push_back(elementTypeID);
841  if (rankID == 0)
842  return success();
843  operands.push_back(rankID);
844  if (shapeID == 0)
845  return success();
846  operands.push_back(shapeID);
847  return success();
848  }
849 
850  // TODO: Handle other types.
851  return emitError(loc, "unhandled type in serialization: ") << type;
852 }
853 
854 LogicalResult
855 Serializer::prepareFunctionType(Location loc, FunctionType type,
856  spirv::Opcode &typeEnum,
857  SmallVectorImpl<uint32_t> &operands) {
858  typeEnum = spirv::Opcode::OpTypeFunction;
859  assert(type.getNumResults() <= 1 &&
860  "serialization supports only a single return value");
861  uint32_t resultID = 0;
862  if (failed(processType(
863  loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
864  resultID))) {
865  return failure();
866  }
867  operands.push_back(resultID);
868  for (auto &res : type.getInputs()) {
869  uint32_t argTypeID = 0;
870  if (failed(processType(loc, res, argTypeID))) {
871  return failure();
872  }
873  operands.push_back(argTypeID);
874  }
875  return success();
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // Constant
880 //===----------------------------------------------------------------------===//
881 
882 uint32_t Serializer::prepareConstant(Location loc, Type constType,
883  Attribute valueAttr) {
884  if (auto id = prepareConstantScalar(loc, valueAttr)) {
885  return id;
886  }
887 
888  // This is a composite literal. We need to handle each component separately
889  // and then emit an OpConstantComposite for the whole.
890 
891  if (auto id = getConstantID(valueAttr)) {
892  return id;
893  }
894 
895  uint32_t typeID = 0;
896  if (failed(processType(loc, constType, typeID))) {
897  return 0;
898  }
899 
900  uint32_t resultID = 0;
901  if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
902  int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
903  SmallVector<uint64_t, 4> index(rank);
904  resultID = prepareDenseElementsConstant(loc, constType, attr,
905  /*dim=*/0, index);
906  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
907  resultID = prepareArrayConstant(loc, constType, arrayAttr);
908  }
909 
910  if (resultID == 0) {
911  emitError(loc, "cannot serialize attribute: ") << valueAttr;
912  return 0;
913  }
914 
915  constIDMap[valueAttr] = resultID;
916  return resultID;
917 }
918 
919 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
920  ArrayAttr attr) {
921  uint32_t typeID = 0;
922  if (failed(processType(loc, constType, typeID))) {
923  return 0;
924  }
925 
926  uint32_t resultID = getNextID();
927  SmallVector<uint32_t, 4> operands = {typeID, resultID};
928  operands.reserve(attr.size() + 2);
929  auto elementType = cast<spirv::ArrayType>(constType).getElementType();
930  for (Attribute elementAttr : attr) {
931  if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
932  operands.push_back(elementID);
933  } else {
934  return 0;
935  }
936  }
937  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
938  encodeInstructionInto(typesGlobalValues, opcode, operands);
939 
940  return resultID;
941 }
942 
943 // TODO: Turn the below function into iterative function, instead of
944 // recursive function.
945 uint32_t
946 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
947  DenseElementsAttr valueAttr, int dim,
949  auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
950  assert(dim <= shapedType.getRank());
951  if (shapedType.getRank() == dim) {
952  if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
953  return attr.getType().getElementType().isInteger(1)
954  ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
955  : prepareConstantInt(loc,
956  attr.getValues<IntegerAttr>()[index]);
957  }
958  if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
959  return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
960  }
961  return 0;
962  }
963 
964  uint32_t typeID = 0;
965  if (failed(processType(loc, constType, typeID))) {
966  return 0;
967  }
968 
969  int64_t numberOfConstituents = shapedType.getDimSize(dim);
970  uint32_t resultID = getNextID();
971  SmallVector<uint32_t, 4> operands = {typeID, resultID};
972  auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
973  if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
974  ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
975  if (!innerShape.empty())
976  elementType = spirv::TensorArmType::get(innerShape, elementType);
977  }
978 
979  // "If the Result Type is a cooperative matrix type, then there must be only
980  // one Constituent, with scalar type matching the cooperative matrix Component
981  // Type, and all components of the matrix are initialized to that value."
982  // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
983  if (isa<spirv::CooperativeMatrixType>(constType)) {
984  if (!valueAttr.isSplat()) {
985  emitError(
986  loc,
987  "cannot serialize a non-splat value for a cooperative matrix type");
988  return 0;
989  }
990  // numberOfConstituents is 1, so we only need one more elements in the
991  // SmallVector, so the total is 3 (1 + 2).
992  operands.reserve(3);
993  // We set dim directly to `shapedType.getRank()` so the recursive call
994  // directly returns the scalar type.
995  if (auto elementID = prepareDenseElementsConstant(
996  loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
997  operands.push_back(elementID);
998  } else {
999  return 0;
1000  }
1001  } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
1002  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1003  {typeID, resultID});
1004  return resultID;
1005  } else {
1006  operands.reserve(numberOfConstituents + 2);
1007  for (int i = 0; i < numberOfConstituents; ++i) {
1008  index[dim] = i;
1009  if (auto elementID = prepareDenseElementsConstant(
1010  loc, elementType, valueAttr, dim + 1, index)) {
1011  operands.push_back(elementID);
1012  } else {
1013  return 0;
1014  }
1015  }
1016  }
1017  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1018  encodeInstructionInto(typesGlobalValues, opcode, operands);
1019 
1020  return resultID;
1021 }
1022 
1023 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1024  bool isSpec) {
1025  if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1026  return prepareConstantFp(loc, floatAttr, isSpec);
1027  }
1028  if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1029  return prepareConstantBool(loc, boolAttr, isSpec);
1030  }
1031  if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1032  return prepareConstantInt(loc, intAttr, isSpec);
1033  }
1034 
1035  return 0;
1036 }
1037 
1038 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1039  bool isSpec) {
1040  if (!isSpec) {
1041  // We can de-duplicate normal constants, but not specialization constants.
1042  if (auto id = getConstantID(boolAttr)) {
1043  return id;
1044  }
1045  }
1046 
1047  // Process the type for this bool literal
1048  uint32_t typeID = 0;
1049  if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
1050  return 0;
1051  }
1052 
1053  auto resultID = getNextID();
1054  auto opcode = boolAttr.getValue()
1055  ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1056  : spirv::Opcode::OpConstantTrue)
1057  : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1058  : spirv::Opcode::OpConstantFalse);
1059  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1060 
1061  if (!isSpec) {
1062  constIDMap[boolAttr] = resultID;
1063  }
1064  return resultID;
1065 }
1066 
1067 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1068  bool isSpec) {
1069  if (!isSpec) {
1070  // We can de-duplicate normal constants, but not specialization constants.
1071  if (auto id = getConstantID(intAttr)) {
1072  return id;
1073  }
1074  }
1075 
1076  // Process the type for this integer literal
1077  uint32_t typeID = 0;
1078  if (failed(processType(loc, intAttr.getType(), typeID))) {
1079  return 0;
1080  }
1081 
1082  auto resultID = getNextID();
1083  APInt value = intAttr.getValue();
1084  unsigned bitwidth = value.getBitWidth();
1085  bool isSigned = intAttr.getType().isSignedInteger();
1086  auto opcode =
1087  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1088 
1089  switch (bitwidth) {
1090  // According to SPIR-V spec, "When the type's bit width is less than
1091  // 32-bits, the literal's value appears in the low-order bits of the word,
1092  // and the high-order bits must be 0 for a floating-point type, or 0 for an
1093  // integer type with Signedness of 0, or sign extended when Signedness
1094  // is 1."
1095  case 32:
1096  case 16:
1097  case 8: {
1098  uint32_t word = 0;
1099  if (isSigned) {
1100  word = static_cast<int32_t>(value.getSExtValue());
1101  } else {
1102  word = static_cast<uint32_t>(value.getZExtValue());
1103  }
1104  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1105  } break;
1106  // According to SPIR-V spec: "When the type's bit width is larger than one
1107  // word, the literal’s low-order words appear first."
1108  case 64: {
1109  struct DoubleWord {
1110  uint32_t word1;
1111  uint32_t word2;
1112  } words;
1113  if (isSigned) {
1114  words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1115  } else {
1116  words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1117  }
1118  encodeInstructionInto(typesGlobalValues, opcode,
1119  {typeID, resultID, words.word1, words.word2});
1120  } break;
1121  default: {
1122  std::string valueStr;
1123  llvm::raw_string_ostream rss(valueStr);
1124  value.print(rss, /*isSigned=*/false);
1125 
1126  emitError(loc, "cannot serialize ")
1127  << bitwidth << "-bit integer literal: " << valueStr;
1128  return 0;
1129  }
1130  }
1131 
1132  if (!isSpec) {
1133  constIDMap[intAttr] = resultID;
1134  }
1135  return resultID;
1136 }
1137 
1138 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1139  bool isSpec) {
1140  if (!isSpec) {
1141  // We can de-duplicate normal constants, but not specialization constants.
1142  if (auto id = getConstantID(floatAttr)) {
1143  return id;
1144  }
1145  }
1146 
1147  // Process the type for this float literal
1148  uint32_t typeID = 0;
1149  if (failed(processType(loc, floatAttr.getType(), typeID))) {
1150  return 0;
1151  }
1152 
1153  auto resultID = getNextID();
1154  APFloat value = floatAttr.getValue();
1155  const llvm::fltSemantics *semantics = &value.getSemantics();
1156 
1157  auto opcode =
1158  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1159 
1160  if (semantics == &APFloat::IEEEsingle()) {
1161  uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1162  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1163  } else if (semantics == &APFloat::IEEEdouble()) {
1164  struct DoubleWord {
1165  uint32_t word1;
1166  uint32_t word2;
1167  } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1168  encodeInstructionInto(typesGlobalValues, opcode,
1169  {typeID, resultID, words.word1, words.word2});
1170  } else if (semantics == &APFloat::IEEEhalf() ||
1171  semantics == &APFloat::BFloat()) {
1172  uint32_t word =
1173  static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1174  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1175  } else {
1176  std::string valueStr;
1177  llvm::raw_string_ostream rss(valueStr);
1178  value.print(rss);
1179 
1180  emitError(loc, "cannot serialize ")
1181  << floatAttr.getType() << "-typed float literal: " << valueStr;
1182  return 0;
1183  }
1184 
1185  if (!isSpec) {
1186  constIDMap[floatAttr] = resultID;
1187  }
1188  return resultID;
1189 }
1190 
1191 // Returns type of attribute. In case of a TypedAttr this will simply return
1192 // the type. But for an ArrayAttr which is untyped and can be multidimensional
1193 // it creates the ArrayType recursively.
1195  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1196  return typedAttr.getType();
1197  }
1198 
1199  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1200  return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
1201  }
1202 
1203  return nullptr;
1204 }
1205 
1206 uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
1207  Type resultType,
1208  Attribute valueAttr) {
1209  std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1210  if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
1211  return id;
1212  }
1213 
1214  uint32_t typeID = 0;
1215  if (failed(processType(loc, resultType, typeID))) {
1216  return 0;
1217  }
1218 
1219  Type valueType = getValueType(valueAttr);
1220  if (!valueAttr)
1221  return 0;
1222 
1223  auto compositeType = dyn_cast<CompositeType>(resultType);
1224  if (!compositeType)
1225  return 0;
1226  Type elementType = compositeType.getElementType(0);
1227 
1228  uint32_t constandID;
1229  if (elementType == valueType) {
1230  constandID = prepareConstant(loc, elementType, valueAttr);
1231  } else {
1232  constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1233  }
1234 
1235  uint32_t resultID = getNextID();
1236  if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
1237  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1238  {typeID, resultID});
1239  } else {
1240  encodeInstructionInto(typesGlobalValues,
1241  spirv::Opcode::OpConstantCompositeReplicateEXT,
1242  {typeID, resultID, constandID});
1243  }
1244 
1245  constCompositeReplicateIDMap[valueTypePair] = resultID;
1246  return resultID;
1247 }
1248 
1249 //===----------------------------------------------------------------------===//
1250 // Control flow
1251 //===----------------------------------------------------------------------===//
1252 
1253 uint32_t Serializer::getOrCreateBlockID(Block *block) {
1254  if (uint32_t id = getBlockID(block))
1255  return id;
1256  return blockIDMap[block] = getNextID();
1257 }
1258 
1259 #ifndef NDEBUG
1260 void Serializer::printBlock(Block *block, raw_ostream &os) {
1261  os << "block " << block << " (id = ";
1262  if (uint32_t id = getBlockID(block))
1263  os << id;
1264  else
1265  os << "unknown";
1266  os << ")\n";
1267 }
1268 #endif
1269 
1270 LogicalResult
1271 Serializer::processBlock(Block *block, bool omitLabel,
1272  function_ref<LogicalResult()> emitMerge) {
1273  LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1274  LLVM_DEBUG(block->print(llvm::dbgs()));
1275  LLVM_DEBUG(llvm::dbgs() << '\n');
1276  if (!omitLabel) {
1277  uint32_t blockID = getOrCreateBlockID(block);
1278  LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1279 
1280  // Emit OpLabel for this block.
1281  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1282  }
1283 
1284  // Emit OpPhi instructions for block arguments, if any.
1285  if (failed(emitPhiForBlockArguments(block)))
1286  return failure();
1287 
1288  // If we need to emit merge instructions, it must happen in this block. Check
1289  // whether we have other structured control flow ops, which will be expanded
1290  // into multiple basic blocks. If that's the case, we need to emit the merge
1291  // right now and then create new blocks for further serialization of the ops
1292  // in this block.
1293  if (emitMerge &&
1294  llvm::any_of(block->getOperations(),
1295  llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1296  if (failed(emitMerge()))
1297  return failure();
1298  emitMerge = nullptr;
1299 
1300  // Start a new block for further serialization.
1301  uint32_t blockID = getNextID();
1302  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1303  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1304  }
1305 
1306  // Process each op in this block except the terminator.
1307  for (Operation &op : llvm::drop_end(*block)) {
1308  if (failed(processOperation(&op)))
1309  return failure();
1310  }
1311 
1312  // Process the terminator.
1313  if (emitMerge)
1314  if (failed(emitMerge()))
1315  return failure();
1316  if (failed(processOperation(&block->back())))
1317  return failure();
1318 
1319  return success();
1320 }
1321 
1322 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1323  // Nothing to do if this block has no arguments or it's the entry block, which
1324  // always has the same arguments as the function signature.
1325  if (block->args_empty() || block->isEntryBlock())
1326  return success();
1327 
1328  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1329 
1330  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1331  // A SPIR-V OpPhi instruction is of the syntax:
1332  // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1333  // So we need to collect all predecessor blocks and the arguments they send
1334  // to this block.
1336  for (Block *mlirPredecessor : block->getPredecessors()) {
1337  auto *terminator = mlirPredecessor->getTerminator();
1338  LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1339  LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1340  LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1341  // The predecessor here is the immediate one according to MLIR's IR
1342  // structure. It does not directly map to the incoming parent block for the
1343  // OpPhi instructions at SPIR-V binary level. This is because structured
1344  // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1345  // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1346  // the branch op jumping to the OpPhi's block then resides in the previous
1347  // structured control flow op's merge block.
1348  Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1349  LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1350  LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1351  if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1352  predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1353  } else if (auto branchCondOp =
1354  dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1355  std::optional<OperandRange> blockOperands;
1356  if (branchCondOp.getTrueTarget() == block) {
1357  blockOperands = branchCondOp.getTrueTargetOperands();
1358  } else {
1359  assert(branchCondOp.getFalseTarget() == block);
1360  blockOperands = branchCondOp.getFalseTargetOperands();
1361  }
1362 
1363  assert(!blockOperands->empty() &&
1364  "expected non-empty block operand range");
1365  predecessors.emplace_back(spirvPredecessor, *blockOperands);
1366  } else {
1367  return terminator->emitError("unimplemented terminator for Phi creation");
1368  }
1369  LLVM_DEBUG({
1370  llvm::dbgs() << " block arguments:\n";
1371  for (Value v : predecessors.back().second)
1372  llvm::dbgs() << " " << v << "\n";
1373  });
1374  }
1375 
1376  // Then create OpPhi instruction for each of the block argument.
1377  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1378  BlockArgument arg = block->getArgument(argIndex);
1379 
1380  // Get the type <id> and result <id> for this OpPhi instruction.
1381  uint32_t phiTypeID = 0;
1382  if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1383  return failure();
1384  uint32_t phiID = getNextID();
1385 
1386  LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1387  << arg << " (id = " << phiID << ")\n");
1388 
1389  // Prepare the (value <id>, parent block <id>) pairs.
1390  SmallVector<uint32_t, 8> phiArgs;
1391  phiArgs.push_back(phiTypeID);
1392  phiArgs.push_back(phiID);
1393 
1394  for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1395  Value value = predecessors[predIndex].second[argIndex];
1396  uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1397  LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1398  << ") value " << value << ' ');
1399  // Each pair is a value <id> ...
1400  uint32_t valueId = getValueID(value);
1401  if (valueId == 0) {
1402  // The op generating this value hasn't been visited yet so we don't have
1403  // an <id> assigned yet. Record this to fix up later.
1404  LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1405  deferredPhiValues[value].push_back(functionBody.size() + 1 +
1406  phiArgs.size());
1407  } else {
1408  LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1409  }
1410  phiArgs.push_back(valueId);
1411  // ... and a parent block <id>.
1412  phiArgs.push_back(predBlockId);
1413  }
1414 
1415  encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1416  valueIDMap[arg] = phiID;
1417  }
1418 
1419  return success();
1420 }
1421 
1422 //===----------------------------------------------------------------------===//
1423 // Operation
1424 //===----------------------------------------------------------------------===//
1425 
1426 LogicalResult Serializer::encodeExtensionInstruction(
1427  Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1428  ArrayRef<uint32_t> operands) {
1429  // Check if the extension has been imported.
1430  auto &setID = extendedInstSetIDMap[extensionSetName];
1431  if (!setID) {
1432  setID = getNextID();
1433  SmallVector<uint32_t, 16> importOperands;
1434  importOperands.push_back(setID);
1435  spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1436  encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1437  importOperands);
1438  }
1439 
1440  // The first two operands are the result type <id> and result <id>. The set
1441  // <id> and the opcode need to be insert after this.
1442  if (operands.size() < 2) {
1443  return op->emitError("extended instructions must have a result encoding");
1444  }
1445  SmallVector<uint32_t, 8> extInstOperands;
1446  extInstOperands.reserve(operands.size() + 2);
1447  extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1448  extInstOperands.push_back(setID);
1449  extInstOperands.push_back(extensionOpcode);
1450  extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1451  encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1452  extInstOperands);
1453  return success();
1454 }
1455 
1456 LogicalResult Serializer::processOperation(Operation *opInst) {
1457  LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1458 
1459  // First dispatch the ops that do not directly mirror an instruction from
1460  // the SPIR-V spec.
1462  .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1463  .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1464  .Case([&](spirv::BranchConditionalOp op) {
1465  return processBranchConditionalOp(op);
1466  })
1467  .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1468  .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1469  return processConstantCompositeReplicateOp(op);
1470  })
1471  .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1472  .Case([&](spirv::GlobalVariableOp op) {
1473  return processGlobalVariableOp(op);
1474  })
1475  .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1476  .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1477  .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1478  .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1479  .Case([&](spirv::SpecConstantCompositeOp op) {
1480  return processSpecConstantCompositeOp(op);
1481  })
1482  .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1483  return processSpecConstantCompositeReplicateOp(op);
1484  })
1485  .Case([&](spirv::SpecConstantOperationOp op) {
1486  return processSpecConstantOperationOp(op);
1487  })
1488  .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1489  .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1490 
1491  // Then handle all the ops that directly mirror SPIR-V instructions with
1492  // auto-generated methods.
1493  .Default(
1494  [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1495 }
1496 
1497 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1498  StringRef extInstSet,
1499  uint32_t opcode) {
1500  SmallVector<uint32_t, 4> operands;
1501  Location loc = op->getLoc();
1502 
1503  uint32_t resultID = 0;
1504  if (op->getNumResults() != 0) {
1505  uint32_t resultTypeID = 0;
1506  if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1507  return failure();
1508  operands.push_back(resultTypeID);
1509 
1510  resultID = getNextID();
1511  operands.push_back(resultID);
1512  valueIDMap[op->getResult(0)] = resultID;
1513  };
1514 
1515  for (Value operand : op->getOperands())
1516  operands.push_back(getValueID(operand));
1517 
1518  if (failed(emitDebugLine(functionBody, loc)))
1519  return failure();
1520 
1521  if (extInstSet.empty()) {
1522  encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1523  operands);
1524  } else {
1525  if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1526  return failure();
1527  }
1528 
1529  if (op->getNumResults() != 0) {
1530  for (auto attr : op->getAttrs()) {
1531  if (failed(processDecoration(loc, resultID, attr)))
1532  return failure();
1533  }
1534  }
1535 
1536  return success();
1537 }
1538 
1539 LogicalResult Serializer::emitDecoration(uint32_t target,
1540  spirv::Decoration decoration,
1541  ArrayRef<uint32_t> params) {
1542  uint32_t wordCount = 3 + params.size();
1543  llvm::append_values(
1544  decorations,
1545  spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1546  static_cast<uint32_t>(decoration));
1547  llvm::append_range(decorations, params);
1548  return success();
1549 }
1550 
1551 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1552  Location loc) {
1553  if (!options.emitDebugInfo)
1554  return success();
1555 
1556  if (lastProcessedWasMergeInst) {
1557  lastProcessedWasMergeInst = false;
1558  return success();
1559  }
1560 
1561  auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1562  if (fileLoc)
1563  encodeInstructionInto(binary, spirv::Opcode::OpLine,
1564  {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1565  return success();
1566 }
1567 } // namespace spirv
1568 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static bool isZeroValue(Attribute attr)
Definition: Serializer.cpp:73
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
Definition: Serializer.cpp:36
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the parent block for SPIR-V OpPhi instructions corresponding to the block arguments.
Definition: Serializer.cpp:47
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Definition: Attributes.cpp:37
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:240
OpListType & getOperations()
Definition: Block.h:137
void print(raw_ostream &os)
bool args_empty()
Definition: Block.h:99
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:195
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:271
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
Definition: Builders.cpp:174
An attribute that represents a reference to a dense vector or tensor object.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:66
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:514
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Definition: Serializer.cpp:160
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
Definition: Serializer.cpp:104
LogicalResult serialize()
Serializes the remembered SPIR-V module.
Definition: Serializer.cpp:108
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
Definition: Serializer.cpp:134
SPIR-V struct type.
Definition: SPIRVTypes.h:295
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
Definition: Serializer.cpp:262
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:97
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Definition: Serializer.cpp:244
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
Definition: Serialization.h:28
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.
Definition: Serialization.h:30