MLIR  20.0.0git
Serializer.cpp
Go to the documentation of this file.
1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
27 #include <cstdint>
28 #include <optional>
29 
30 #define DEBUG_TYPE "spirv-serialization"
31 
32 using namespace mlir;
33 
34 /// Returns the merge block if the given `op` is a structured control flow op.
35 /// Otherwise returns nullptr.
37  if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38  return selectionOp.getMergeBlock();
39  if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40  return loopOp.getMergeBlock();
41  return nullptr;
42 }
43 
44 /// Given a predecessor `block` for a block with arguments, returns the block
45 /// that should be used as the parent block for SPIR-V OpPhi instructions
46 /// corresponding to the block arguments.
47 static Block *getPhiIncomingBlock(Block *block) {
48  // If the predecessor block in question is the entry block for a
49  // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50  if (block->isEntryBlock()) {
51  if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52  // Then the incoming parent block for OpPhi should be the merge block of
53  // the structured control flow op before this loop.
54  Operation *op = loopOp.getOperation();
55  while ((op = op->getPrevNode()) != nullptr)
56  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57  return incomingBlock;
58  // Or the enclosing block itself if no structured control flow ops
59  // exists before this loop.
60  return loopOp->getBlock();
61  }
62  }
63 
64  // Otherwise, we jump from the given predecessor block. Try to see if there is
65  // a structured control flow op inside it.
66  for (Operation &op : llvm::reverse(block->getOperations())) {
67  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68  return incomingBlock;
69  }
70  return block;
71 }
72 
73 namespace mlir {
74 namespace spirv {
75 
76 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
77 /// the given `binary` vector.
78 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
79  ArrayRef<uint32_t> operands) {
80  uint32_t wordCount = 1 + operands.size();
81  binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
82  binary.append(operands.begin(), operands.end());
83 }
84 
85 Serializer::Serializer(spirv::ModuleOp module,
87  : module(module), mlirBuilder(module.getContext()), options(options) {}
88 
89 LogicalResult Serializer::serialize() {
90  LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
91 
92  if (failed(module.verifyInvariants()))
93  return failure();
94 
95  // TODO: handle the other sections
96  processCapability();
97  processExtension();
98  processMemoryModel();
99  processDebugInfo();
100 
101  // Iterate over the module body to serialize it. Assumptions are that there is
102  // only one basic block in the moduleOp
103  for (auto &op : *module.getBody()) {
104  if (failed(processOperation(&op))) {
105  return failure();
106  }
107  }
108 
109  LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
110  return success();
111 }
112 
114  auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
115  extensions.size() + extendedSets.size() +
116  memoryModel.size() + entryPoints.size() +
117  executionModes.size() + decorations.size() +
118  typesGlobalValues.size() + functions.size();
119 
120  binary.clear();
121  binary.reserve(moduleSize);
122 
123  spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
124  nextID);
125  binary.append(capabilities.begin(), capabilities.end());
126  binary.append(extensions.begin(), extensions.end());
127  binary.append(extendedSets.begin(), extendedSets.end());
128  binary.append(memoryModel.begin(), memoryModel.end());
129  binary.append(entryPoints.begin(), entryPoints.end());
130  binary.append(executionModes.begin(), executionModes.end());
131  binary.append(debug.begin(), debug.end());
132  binary.append(names.begin(), names.end());
133  binary.append(decorations.begin(), decorations.end());
134  binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
135  binary.append(functions.begin(), functions.end());
136 }
137 
138 #ifndef NDEBUG
139 void Serializer::printValueIDMap(raw_ostream &os) {
140  os << "\n= Value <id> Map =\n\n";
141  for (auto valueIDPair : valueIDMap) {
142  Value val = valueIDPair.first;
143  os << " " << val << " "
144  << "id = " << valueIDPair.second << ' ';
145  if (auto *op = val.getDefiningOp()) {
146  os << "from op '" << op->getName() << "'";
147  } else if (auto arg = dyn_cast<BlockArgument>(val)) {
148  Block *block = arg.getOwner();
149  os << "from argument of block " << block << ' ';
150  os << " in op '" << block->getParentOp()->getName() << "'";
151  }
152  os << '\n';
153  }
154 }
155 #endif
156 
157 //===----------------------------------------------------------------------===//
158 // Module structure
159 //===----------------------------------------------------------------------===//
160 
161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162  auto funcID = funcIDMap.lookup(fnName);
163  if (!funcID) {
164  funcID = getNextID();
165  funcIDMap[fnName] = funcID;
166  }
167  return funcID;
168 }
169 
170 void Serializer::processCapability() {
171  for (auto cap : module.getVceTriple()->getCapabilities())
172  encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
173  {static_cast<uint32_t>(cap)});
174 }
175 
176 void Serializer::processDebugInfo() {
177  if (!options.emitDebugInfo)
178  return;
179  auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180  auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
181  fileID = getNextID();
182  SmallVector<uint32_t, 16> operands;
183  operands.push_back(fileID);
184  spirv::encodeStringLiteralInto(operands, fileName);
185  encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
186  // TODO: Encode more debug instructions.
187 }
188 
189 void Serializer::processExtension() {
191  for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
192  extName.clear();
193  spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
194  encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
195  }
196 }
197 
198 void Serializer::processMemoryModel() {
199  StringAttr memoryModelName = module.getMemoryModelAttrName();
200  auto mm = static_cast<uint32_t>(
201  module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
202  .getValue());
203 
204  StringAttr addressingModelName = module.getAddressingModelAttrName();
205  auto am = static_cast<uint32_t>(
206  module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
207  .getValue());
208 
209  encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
210 }
211 
212 static std::string getDecorationName(StringRef attrName) {
213  // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
214  // expected FPFastMathMode.
215  if (attrName == "fp_fast_math_mode")
216  return "FPFastMathMode";
217 
218  return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
219 }
220 
221 LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
222  Decoration decoration,
223  Attribute attr) {
225  switch (decoration) {
226  case spirv::Decoration::LinkageAttributes: {
227  // Get the value of the Linkage Attributes
228  // e.g., LinkageAttributes=["linkageName", linkageType].
229  auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
230  auto linkageName = linkageAttr.getLinkageName();
231  auto linkageType = linkageAttr.getLinkageType().getValue();
232  // Encode the Linkage Name (string literal to uint32_t).
233  spirv::encodeStringLiteralInto(args, linkageName);
234  // Encode LinkageType & Add the Linkagetype to the args.
235  args.push_back(static_cast<uint32_t>(linkageType));
236  break;
237  }
238  case spirv::Decoration::FPFastMathMode:
239  if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
240  args.push_back(static_cast<uint32_t>(intAttr.getValue()));
241  break;
242  }
243  return emitError(loc, "expected FPFastMathModeAttr attribute for ")
244  << stringifyDecoration(decoration);
245  case spirv::Decoration::Binding:
246  case spirv::Decoration::DescriptorSet:
247  case spirv::Decoration::Location:
248  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
249  args.push_back(intAttr.getValue().getZExtValue());
250  break;
251  }
252  return emitError(loc, "expected integer attribute for ")
253  << stringifyDecoration(decoration);
254  case spirv::Decoration::BuiltIn:
255  if (auto strAttr = dyn_cast<StringAttr>(attr)) {
256  auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
257  if (enumVal) {
258  args.push_back(static_cast<uint32_t>(*enumVal));
259  break;
260  }
261  return emitError(loc, "invalid ")
262  << stringifyDecoration(decoration) << " decoration attribute "
263  << strAttr.getValue();
264  }
265  return emitError(loc, "expected string attribute for ")
266  << stringifyDecoration(decoration);
267  case spirv::Decoration::Aliased:
268  case spirv::Decoration::AliasedPointer:
269  case spirv::Decoration::Flat:
270  case spirv::Decoration::NonReadable:
271  case spirv::Decoration::NonWritable:
272  case spirv::Decoration::NoPerspective:
273  case spirv::Decoration::NoSignedWrap:
274  case spirv::Decoration::NoUnsignedWrap:
275  case spirv::Decoration::RelaxedPrecision:
276  case spirv::Decoration::Restrict:
277  case spirv::Decoration::RestrictPointer:
278  case spirv::Decoration::NoContraction:
279  // For unit attributes and decoration attributes, the args list
280  // has no values so we do nothing.
281  if (isa<UnitAttr, DecorationAttr>(attr))
282  break;
283  return emitError(loc,
284  "expected unit attribute or decoration attribute for ")
285  << stringifyDecoration(decoration);
286  default:
287  return emitError(loc, "unhandled decoration ")
288  << stringifyDecoration(decoration);
289  }
290  return emitDecoration(resultID, decoration, args);
291 }
292 
293 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
294  NamedAttribute attr) {
295  StringRef attrName = attr.getName().strref();
296  std::string decorationName = getDecorationName(attrName);
297  std::optional<Decoration> decoration =
298  spirv::symbolizeDecoration(decorationName);
299  if (!decoration) {
300  return emitError(
301  loc, "non-argument attributes expected to have snake-case-ified "
302  "decoration name, unhandled attribute with name : ")
303  << attrName;
304  }
305  return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
306 }
307 
308 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
309  assert(!name.empty() && "unexpected empty string for OpName");
310  if (!options.emitSymbolName)
311  return success();
312 
313  SmallVector<uint32_t, 4> nameOperands;
314  nameOperands.push_back(resultID);
315  spirv::encodeStringLiteralInto(nameOperands, name);
316  encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
317  return success();
318 }
319 
320 template <>
321 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
322  Location loc, spirv::ArrayType type, uint32_t resultID) {
323  if (unsigned stride = type.getArrayStride()) {
324  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
325  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
326  }
327  return success();
328 }
329 
330 template <>
331 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
332  Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
333  if (unsigned stride = type.getArrayStride()) {
334  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
335  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
336  }
337  return success();
338 }
339 
340 LogicalResult Serializer::processMemberDecoration(
341  uint32_t structID,
342  const spirv::StructType::MemberDecorationInfo &memberDecoration) {
344  {structID, memberDecoration.memberIndex,
345  static_cast<uint32_t>(memberDecoration.decoration)});
346  if (memberDecoration.hasValue) {
347  args.push_back(memberDecoration.decorationValue);
348  }
349  encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
350  return success();
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // Type
355 //===----------------------------------------------------------------------===//
356 
357 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
358 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
359 // PushConstant Storage Classes must be explicitly laid out."
360 bool Serializer::isInterfaceStructPtrType(Type type) const {
361  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
362  switch (ptrType.getStorageClass()) {
363  case spirv::StorageClass::PhysicalStorageBuffer:
364  case spirv::StorageClass::PushConstant:
365  case spirv::StorageClass::StorageBuffer:
366  case spirv::StorageClass::Uniform:
367  return isa<spirv::StructType>(ptrType.getPointeeType());
368  default:
369  break;
370  }
371  }
372  return false;
373 }
374 
375 LogicalResult Serializer::processType(Location loc, Type type,
376  uint32_t &typeID) {
377  // Maintains a set of names for nested identified struct types. This is used
378  // to properly serialize recursive references.
379  SetVector<StringRef> serializationCtx;
380  return processTypeImpl(loc, type, typeID, serializationCtx);
381 }
382 
383 LogicalResult
384 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
385  SetVector<StringRef> &serializationCtx) {
386  typeID = getTypeID(type);
387  if (typeID)
388  return success();
389 
390  typeID = getNextID();
391  SmallVector<uint32_t, 4> operands;
392 
393  operands.push_back(typeID);
394  auto typeEnum = spirv::Opcode::OpTypeVoid;
395  bool deferSerialization = false;
396 
397  if ((isa<FunctionType>(type) &&
398  succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
399  operands))) ||
400  succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
401  deferSerialization, serializationCtx))) {
402  if (deferSerialization)
403  return success();
404 
405  typeIDMap[type] = typeID;
406 
407  encodeInstructionInto(typesGlobalValues, typeEnum, operands);
408 
409  if (recursiveStructInfos.count(type) != 0) {
410  // This recursive struct type is emitted already, now the OpTypePointer
411  // instructions referring to recursive references are emitted as well.
412  for (auto &ptrInfo : recursiveStructInfos[type]) {
413  // TODO: This might not work if more than 1 recursive reference is
414  // present in the struct.
415  SmallVector<uint32_t, 4> ptrOperands;
416  ptrOperands.push_back(ptrInfo.pointerTypeID);
417  ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
418  ptrOperands.push_back(typeIDMap[type]);
419 
420  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
421  ptrOperands);
422  }
423 
424  recursiveStructInfos[type].clear();
425  }
426 
427  return success();
428  }
429 
430  return failure();
431 }
432 
433 LogicalResult Serializer::prepareBasicType(
434  Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
435  SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
436  SetVector<StringRef> &serializationCtx) {
437  deferSerialization = false;
438 
439  if (isVoidType(type)) {
440  typeEnum = spirv::Opcode::OpTypeVoid;
441  return success();
442  }
443 
444  if (auto intType = dyn_cast<IntegerType>(type)) {
445  if (intType.getWidth() == 1) {
446  typeEnum = spirv::Opcode::OpTypeBool;
447  return success();
448  }
449 
450  typeEnum = spirv::Opcode::OpTypeInt;
451  operands.push_back(intType.getWidth());
452  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
453  // to preserve or validate.
454  // 0 indicates unsigned, or no signedness semantics
455  // 1 indicates signed semantics."
456  operands.push_back(intType.isSigned() ? 1 : 0);
457  return success();
458  }
459 
460  if (auto floatType = dyn_cast<FloatType>(type)) {
461  typeEnum = spirv::Opcode::OpTypeFloat;
462  operands.push_back(floatType.getWidth());
463  return success();
464  }
465 
466  if (auto vectorType = dyn_cast<VectorType>(type)) {
467  uint32_t elementTypeID = 0;
468  if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
469  serializationCtx))) {
470  return failure();
471  }
472  typeEnum = spirv::Opcode::OpTypeVector;
473  operands.push_back(elementTypeID);
474  operands.push_back(vectorType.getNumElements());
475  return success();
476  }
477 
478  if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
479  typeEnum = spirv::Opcode::OpTypeImage;
480  uint32_t sampledTypeID = 0;
481  if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
482  return failure();
483 
484  llvm::append_values(operands, sampledTypeID,
485  static_cast<uint32_t>(imageType.getDim()),
486  static_cast<uint32_t>(imageType.getDepthInfo()),
487  static_cast<uint32_t>(imageType.getArrayedInfo()),
488  static_cast<uint32_t>(imageType.getSamplingInfo()),
489  static_cast<uint32_t>(imageType.getSamplerUseInfo()),
490  static_cast<uint32_t>(imageType.getImageFormat()));
491  return success();
492  }
493 
494  if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
495  typeEnum = spirv::Opcode::OpTypeArray;
496  uint32_t elementTypeID = 0;
497  if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
498  serializationCtx))) {
499  return failure();
500  }
501  operands.push_back(elementTypeID);
502  if (auto elementCountID = prepareConstantInt(
503  loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
504  operands.push_back(elementCountID);
505  }
506  return processTypeDecoration(loc, arrayType, resultID);
507  }
508 
509  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
510  uint32_t pointeeTypeID = 0;
511  spirv::StructType pointeeStruct =
512  dyn_cast<spirv::StructType>(ptrType.getPointeeType());
513 
514  if (pointeeStruct && pointeeStruct.isIdentified() &&
515  serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
516  // A recursive reference to an enclosing struct is found.
517  //
518  // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
519  // class as operands.
520  SmallVector<uint32_t, 2> forwardPtrOperands;
521  forwardPtrOperands.push_back(resultID);
522  forwardPtrOperands.push_back(
523  static_cast<uint32_t>(ptrType.getStorageClass()));
524 
525  encodeInstructionInto(typesGlobalValues,
526  spirv::Opcode::OpTypeForwardPointer,
527  forwardPtrOperands);
528 
529  // 2. Find the pointee (enclosing) struct.
530  auto structType = spirv::StructType::getIdentified(
531  module.getContext(), pointeeStruct.getIdentifier());
532 
533  if (!structType)
534  return failure();
535 
536  // 3. Mark the OpTypePointer that is supposed to be emitted by this call
537  // as deferred.
538  deferSerialization = true;
539 
540  // 4. Record the info needed to emit the deferred OpTypePointer
541  // instruction when the enclosing struct is completely serialized.
542  recursiveStructInfos[structType].push_back(
543  {resultID, ptrType.getStorageClass()});
544  } else {
545  if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
546  serializationCtx)))
547  return failure();
548  }
549 
550  typeEnum = spirv::Opcode::OpTypePointer;
551  operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
552  operands.push_back(pointeeTypeID);
553 
554  if (isInterfaceStructPtrType(ptrType)) {
555  if (failed(emitDecoration(getTypeID(pointeeStruct),
556  spirv::Decoration::Block)))
557  return emitError(loc, "cannot decorate ")
558  << pointeeStruct << " with Block decoration";
559  }
560 
561  return success();
562  }
563 
564  if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
565  uint32_t elementTypeID = 0;
566  if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
567  elementTypeID, serializationCtx))) {
568  return failure();
569  }
570  typeEnum = spirv::Opcode::OpTypeRuntimeArray;
571  operands.push_back(elementTypeID);
572  return processTypeDecoration(loc, runtimeArrayType, resultID);
573  }
574 
575  if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
576  typeEnum = spirv::Opcode::OpTypeSampledImage;
577  uint32_t imageTypeID = 0;
578  if (failed(
579  processType(loc, sampledImageType.getImageType(), imageTypeID))) {
580  return failure();
581  }
582  operands.push_back(imageTypeID);
583  return success();
584  }
585 
586  if (auto structType = dyn_cast<spirv::StructType>(type)) {
587  if (structType.isIdentified()) {
588  if (failed(processName(resultID, structType.getIdentifier())))
589  return failure();
590  serializationCtx.insert(structType.getIdentifier());
591  }
592 
593  bool hasOffset = structType.hasOffset();
594  for (auto elementIndex :
595  llvm::seq<uint32_t>(0, structType.getNumElements())) {
596  uint32_t elementTypeID = 0;
597  if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
598  elementTypeID, serializationCtx))) {
599  return failure();
600  }
601  operands.push_back(elementTypeID);
602  if (hasOffset) {
603  // Decorate each struct member with an offset
605  elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
606  static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
607  if (failed(processMemberDecoration(resultID, offsetDecoration))) {
608  return emitError(loc, "cannot decorate ")
609  << elementIndex << "-th member of " << structType
610  << " with its offset";
611  }
612  }
613  }
615  structType.getMemberDecorations(memberDecorations);
616 
617  for (auto &memberDecoration : memberDecorations) {
618  if (failed(processMemberDecoration(resultID, memberDecoration))) {
619  return emitError(loc, "cannot decorate ")
620  << static_cast<uint32_t>(memberDecoration.memberIndex)
621  << "-th member of " << structType << " with "
622  << stringifyDecoration(memberDecoration.decoration);
623  }
624  }
625 
626  typeEnum = spirv::Opcode::OpTypeStruct;
627 
628  if (structType.isIdentified())
629  serializationCtx.remove(structType.getIdentifier());
630 
631  return success();
632  }
633 
634  if (auto cooperativeMatrixType =
635  dyn_cast<spirv::CooperativeMatrixType>(type)) {
636  uint32_t elementTypeID = 0;
637  if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
638  elementTypeID, serializationCtx))) {
639  return failure();
640  }
641  typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
642  auto getConstantOp = [&](uint32_t id) {
643  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
644  return prepareConstantInt(loc, attr);
645  };
646  llvm::append_values(
647  operands, elementTypeID,
648  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
649  getConstantOp(cooperativeMatrixType.getRows()),
650  getConstantOp(cooperativeMatrixType.getColumns()),
651  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
652  return success();
653  }
654 
655  if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
656  uint32_t elementTypeID = 0;
657  if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
658  elementTypeID, serializationCtx))) {
659  return failure();
660  }
661  typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
662  auto getConstantOp = [&](uint32_t id) {
663  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
664  return prepareConstantInt(loc, attr);
665  };
666  llvm::append_values(
667  operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
668  getConstantOp(jointMatrixType.getColumns()),
669  getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
670  getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
671  return success();
672  }
673 
674  if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
675  uint32_t elementTypeID = 0;
676  if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
677  serializationCtx))) {
678  return failure();
679  }
680  typeEnum = spirv::Opcode::OpTypeMatrix;
681  llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
682  return success();
683  }
684 
685  // TODO: Handle other types.
686  return emitError(loc, "unhandled type in serialization: ") << type;
687 }
688 
689 LogicalResult
690 Serializer::prepareFunctionType(Location loc, FunctionType type,
691  spirv::Opcode &typeEnum,
692  SmallVectorImpl<uint32_t> &operands) {
693  typeEnum = spirv::Opcode::OpTypeFunction;
694  assert(type.getNumResults() <= 1 &&
695  "serialization supports only a single return value");
696  uint32_t resultID = 0;
697  if (failed(processType(
698  loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
699  resultID))) {
700  return failure();
701  }
702  operands.push_back(resultID);
703  for (auto &res : type.getInputs()) {
704  uint32_t argTypeID = 0;
705  if (failed(processType(loc, res, argTypeID))) {
706  return failure();
707  }
708  operands.push_back(argTypeID);
709  }
710  return success();
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // Constant
715 //===----------------------------------------------------------------------===//
716 
717 uint32_t Serializer::prepareConstant(Location loc, Type constType,
718  Attribute valueAttr) {
719  if (auto id = prepareConstantScalar(loc, valueAttr)) {
720  return id;
721  }
722 
723  // This is a composite literal. We need to handle each component separately
724  // and then emit an OpConstantComposite for the whole.
725 
726  if (auto id = getConstantID(valueAttr)) {
727  return id;
728  }
729 
730  uint32_t typeID = 0;
731  if (failed(processType(loc, constType, typeID))) {
732  return 0;
733  }
734 
735  uint32_t resultID = 0;
736  if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
737  int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
738  SmallVector<uint64_t, 4> index(rank);
739  resultID = prepareDenseElementsConstant(loc, constType, attr,
740  /*dim=*/0, index);
741  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
742  resultID = prepareArrayConstant(loc, constType, arrayAttr);
743  }
744 
745  if (resultID == 0) {
746  emitError(loc, "cannot serialize attribute: ") << valueAttr;
747  return 0;
748  }
749 
750  constIDMap[valueAttr] = resultID;
751  return resultID;
752 }
753 
754 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
755  ArrayAttr attr) {
756  uint32_t typeID = 0;
757  if (failed(processType(loc, constType, typeID))) {
758  return 0;
759  }
760 
761  uint32_t resultID = getNextID();
762  SmallVector<uint32_t, 4> operands = {typeID, resultID};
763  operands.reserve(attr.size() + 2);
764  auto elementType = cast<spirv::ArrayType>(constType).getElementType();
765  for (Attribute elementAttr : attr) {
766  if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
767  operands.push_back(elementID);
768  } else {
769  return 0;
770  }
771  }
772  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
773  encodeInstructionInto(typesGlobalValues, opcode, operands);
774 
775  return resultID;
776 }
777 
778 // TODO: Turn the below function into iterative function, instead of
779 // recursive function.
780 uint32_t
781 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
782  DenseElementsAttr valueAttr, int dim,
784  auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
785  assert(dim <= shapedType.getRank());
786  if (shapedType.getRank() == dim) {
787  if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
788  return attr.getType().getElementType().isInteger(1)
789  ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
790  : prepareConstantInt(loc,
791  attr.getValues<IntegerAttr>()[index]);
792  }
793  if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
794  return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
795  }
796  return 0;
797  }
798 
799  uint32_t typeID = 0;
800  if (failed(processType(loc, constType, typeID))) {
801  return 0;
802  }
803 
804  uint32_t resultID = getNextID();
805  SmallVector<uint32_t, 4> operands = {typeID, resultID};
806  operands.reserve(shapedType.getDimSize(dim) + 2);
807  auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
808  for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
809  index[dim] = i;
810  if (auto elementID = prepareDenseElementsConstant(
811  loc, elementType, valueAttr, dim + 1, index)) {
812  operands.push_back(elementID);
813  } else {
814  return 0;
815  }
816  }
817  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
818  encodeInstructionInto(typesGlobalValues, opcode, operands);
819 
820  return resultID;
821 }
822 
823 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
824  bool isSpec) {
825  if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
826  return prepareConstantFp(loc, floatAttr, isSpec);
827  }
828  if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
829  return prepareConstantBool(loc, boolAttr, isSpec);
830  }
831  if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
832  return prepareConstantInt(loc, intAttr, isSpec);
833  }
834 
835  return 0;
836 }
837 
838 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
839  bool isSpec) {
840  if (!isSpec) {
841  // We can de-duplicate normal constants, but not specialization constants.
842  if (auto id = getConstantID(boolAttr)) {
843  return id;
844  }
845  }
846 
847  // Process the type for this bool literal
848  uint32_t typeID = 0;
849  if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
850  return 0;
851  }
852 
853  auto resultID = getNextID();
854  auto opcode = boolAttr.getValue()
855  ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
856  : spirv::Opcode::OpConstantTrue)
857  : (isSpec ? spirv::Opcode::OpSpecConstantFalse
858  : spirv::Opcode::OpConstantFalse);
859  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
860 
861  if (!isSpec) {
862  constIDMap[boolAttr] = resultID;
863  }
864  return resultID;
865 }
866 
867 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
868  bool isSpec) {
869  if (!isSpec) {
870  // We can de-duplicate normal constants, but not specialization constants.
871  if (auto id = getConstantID(intAttr)) {
872  return id;
873  }
874  }
875 
876  // Process the type for this integer literal
877  uint32_t typeID = 0;
878  if (failed(processType(loc, intAttr.getType(), typeID))) {
879  return 0;
880  }
881 
882  auto resultID = getNextID();
883  APInt value = intAttr.getValue();
884  unsigned bitwidth = value.getBitWidth();
885  bool isSigned = intAttr.getType().isSignedInteger();
886  auto opcode =
887  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
888 
889  switch (bitwidth) {
890  // According to SPIR-V spec, "When the type's bit width is less than
891  // 32-bits, the literal's value appears in the low-order bits of the word,
892  // and the high-order bits must be 0 for a floating-point type, or 0 for an
893  // integer type with Signedness of 0, or sign extended when Signedness
894  // is 1."
895  case 32:
896  case 16:
897  case 8: {
898  uint32_t word = 0;
899  if (isSigned) {
900  word = static_cast<int32_t>(value.getSExtValue());
901  } else {
902  word = static_cast<uint32_t>(value.getZExtValue());
903  }
904  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
905  } break;
906  // According to SPIR-V spec: "When the type's bit width is larger than one
907  // word, the literal’s low-order words appear first."
908  case 64: {
909  struct DoubleWord {
910  uint32_t word1;
911  uint32_t word2;
912  } words;
913  if (isSigned) {
914  words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
915  } else {
916  words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
917  }
918  encodeInstructionInto(typesGlobalValues, opcode,
919  {typeID, resultID, words.word1, words.word2});
920  } break;
921  default: {
922  std::string valueStr;
923  llvm::raw_string_ostream rss(valueStr);
924  value.print(rss, /*isSigned=*/false);
925 
926  emitError(loc, "cannot serialize ")
927  << bitwidth << "-bit integer literal: " << rss.str();
928  return 0;
929  }
930  }
931 
932  if (!isSpec) {
933  constIDMap[intAttr] = resultID;
934  }
935  return resultID;
936 }
937 
938 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
939  bool isSpec) {
940  if (!isSpec) {
941  // We can de-duplicate normal constants, but not specialization constants.
942  if (auto id = getConstantID(floatAttr)) {
943  return id;
944  }
945  }
946 
947  // Process the type for this float literal
948  uint32_t typeID = 0;
949  if (failed(processType(loc, floatAttr.getType(), typeID))) {
950  return 0;
951  }
952 
953  auto resultID = getNextID();
954  APFloat value = floatAttr.getValue();
955  APInt intValue = value.bitcastToAPInt();
956 
957  auto opcode =
958  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
959 
960  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
961  uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
962  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
963  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
964  struct DoubleWord {
965  uint32_t word1;
966  uint32_t word2;
967  } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
968  encodeInstructionInto(typesGlobalValues, opcode,
969  {typeID, resultID, words.word1, words.word2});
970  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
971  uint32_t word =
972  static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
973  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
974  } else {
975  std::string valueStr;
976  llvm::raw_string_ostream rss(valueStr);
977  value.print(rss);
978 
979  emitError(loc, "cannot serialize ")
980  << floatAttr.getType() << "-typed float literal: " << rss.str();
981  return 0;
982  }
983 
984  if (!isSpec) {
985  constIDMap[floatAttr] = resultID;
986  }
987  return resultID;
988 }
989 
990 //===----------------------------------------------------------------------===//
991 // Control flow
992 //===----------------------------------------------------------------------===//
993 
994 uint32_t Serializer::getOrCreateBlockID(Block *block) {
995  if (uint32_t id = getBlockID(block))
996  return id;
997  return blockIDMap[block] = getNextID();
998 }
999 
1000 #ifndef NDEBUG
1001 void Serializer::printBlock(Block *block, raw_ostream &os) {
1002  os << "block " << block << " (id = ";
1003  if (uint32_t id = getBlockID(block))
1004  os << id;
1005  else
1006  os << "unknown";
1007  os << ")\n";
1008 }
1009 #endif
1010 
1011 LogicalResult
1012 Serializer::processBlock(Block *block, bool omitLabel,
1013  function_ref<LogicalResult()> emitMerge) {
1014  LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1015  LLVM_DEBUG(block->print(llvm::dbgs()));
1016  LLVM_DEBUG(llvm::dbgs() << '\n');
1017  if (!omitLabel) {
1018  uint32_t blockID = getOrCreateBlockID(block);
1019  LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1020 
1021  // Emit OpLabel for this block.
1022  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1023  }
1024 
1025  // Emit OpPhi instructions for block arguments, if any.
1026  if (failed(emitPhiForBlockArguments(block)))
1027  return failure();
1028 
1029  // If we need to emit merge instructions, it must happen in this block. Check
1030  // whether we have other structured control flow ops, which will be expanded
1031  // into multiple basic blocks. If that's the case, we need to emit the merge
1032  // right now and then create new blocks for further serialization of the ops
1033  // in this block.
1034  if (emitMerge &&
1035  llvm::any_of(block->getOperations(),
1036  llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1037  if (failed(emitMerge()))
1038  return failure();
1039  emitMerge = nullptr;
1040 
1041  // Start a new block for further serialization.
1042  uint32_t blockID = getNextID();
1043  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1044  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1045  }
1046 
1047  // Process each op in this block except the terminator.
1048  for (Operation &op : llvm::drop_end(*block)) {
1049  if (failed(processOperation(&op)))
1050  return failure();
1051  }
1052 
1053  // Process the terminator.
1054  if (emitMerge)
1055  if (failed(emitMerge()))
1056  return failure();
1057  if (failed(processOperation(&block->back())))
1058  return failure();
1059 
1060  return success();
1061 }
1062 
1063 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1064  // Nothing to do if this block has no arguments or it's the entry block, which
1065  // always has the same arguments as the function signature.
1066  if (block->args_empty() || block->isEntryBlock())
1067  return success();
1068 
1069  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1070 
1071  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1072  // A SPIR-V OpPhi instruction is of the syntax:
1073  // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1074  // So we need to collect all predecessor blocks and the arguments they send
1075  // to this block.
1077  for (Block *mlirPredecessor : block->getPredecessors()) {
1078  auto *terminator = mlirPredecessor->getTerminator();
1079  LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1080  LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1081  LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1082  // The predecessor here is the immediate one according to MLIR's IR
1083  // structure. It does not directly map to the incoming parent block for the
1084  // OpPhi instructions at SPIR-V binary level. This is because structured
1085  // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1086  // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1087  // the branch op jumping to the OpPhi's block then resides in the previous
1088  // structured control flow op's merge block.
1089  Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1090  LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1091  LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1092  if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1093  predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1094  } else if (auto branchCondOp =
1095  dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1096  std::optional<OperandRange> blockOperands;
1097  if (branchCondOp.getTrueTarget() == block) {
1098  blockOperands = branchCondOp.getTrueTargetOperands();
1099  } else {
1100  assert(branchCondOp.getFalseTarget() == block);
1101  blockOperands = branchCondOp.getFalseTargetOperands();
1102  }
1103 
1104  assert(!blockOperands->empty() &&
1105  "expected non-empty block operand range");
1106  predecessors.emplace_back(spirvPredecessor, *blockOperands);
1107  } else {
1108  return terminator->emitError("unimplemented terminator for Phi creation");
1109  }
1110  LLVM_DEBUG({
1111  llvm::dbgs() << " block arguments:\n";
1112  for (Value v : predecessors.back().second)
1113  llvm::dbgs() << " " << v << "\n";
1114  });
1115  }
1116 
1117  // Then create OpPhi instruction for each of the block argument.
1118  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1119  BlockArgument arg = block->getArgument(argIndex);
1120 
1121  // Get the type <id> and result <id> for this OpPhi instruction.
1122  uint32_t phiTypeID = 0;
1123  if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1124  return failure();
1125  uint32_t phiID = getNextID();
1126 
1127  LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1128  << arg << " (id = " << phiID << ")\n");
1129 
1130  // Prepare the (value <id>, parent block <id>) pairs.
1131  SmallVector<uint32_t, 8> phiArgs;
1132  phiArgs.push_back(phiTypeID);
1133  phiArgs.push_back(phiID);
1134 
1135  for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1136  Value value = predecessors[predIndex].second[argIndex];
1137  uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1138  LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1139  << ") value " << value << ' ');
1140  // Each pair is a value <id> ...
1141  uint32_t valueId = getValueID(value);
1142  if (valueId == 0) {
1143  // The op generating this value hasn't been visited yet so we don't have
1144  // an <id> assigned yet. Record this to fix up later.
1145  LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1146  deferredPhiValues[value].push_back(functionBody.size() + 1 +
1147  phiArgs.size());
1148  } else {
1149  LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1150  }
1151  phiArgs.push_back(valueId);
1152  // ... and a parent block <id>.
1153  phiArgs.push_back(predBlockId);
1154  }
1155 
1156  encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1157  valueIDMap[arg] = phiID;
1158  }
1159 
1160  return success();
1161 }
1162 
1163 //===----------------------------------------------------------------------===//
1164 // Operation
1165 //===----------------------------------------------------------------------===//
1166 
1167 LogicalResult Serializer::encodeExtensionInstruction(
1168  Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1169  ArrayRef<uint32_t> operands) {
1170  // Check if the extension has been imported.
1171  auto &setID = extendedInstSetIDMap[extensionSetName];
1172  if (!setID) {
1173  setID = getNextID();
1174  SmallVector<uint32_t, 16> importOperands;
1175  importOperands.push_back(setID);
1176  spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1177  encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1178  importOperands);
1179  }
1180 
1181  // The first two operands are the result type <id> and result <id>. The set
1182  // <id> and the opcode need to be insert after this.
1183  if (operands.size() < 2) {
1184  return op->emitError("extended instructions must have a result encoding");
1185  }
1186  SmallVector<uint32_t, 8> extInstOperands;
1187  extInstOperands.reserve(operands.size() + 2);
1188  extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1189  extInstOperands.push_back(setID);
1190  extInstOperands.push_back(extensionOpcode);
1191  extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1192  encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1193  extInstOperands);
1194  return success();
1195 }
1196 
1197 LogicalResult Serializer::processOperation(Operation *opInst) {
1198  LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1199 
1200  // First dispatch the ops that do not directly mirror an instruction from
1201  // the SPIR-V spec.
1203  .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1204  .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1205  .Case([&](spirv::BranchConditionalOp op) {
1206  return processBranchConditionalOp(op);
1207  })
1208  .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1209  .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1210  .Case([&](spirv::GlobalVariableOp op) {
1211  return processGlobalVariableOp(op);
1212  })
1213  .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1214  .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1215  .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1216  .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1217  .Case([&](spirv::SpecConstantCompositeOp op) {
1218  return processSpecConstantCompositeOp(op);
1219  })
1220  .Case([&](spirv::SpecConstantOperationOp op) {
1221  return processSpecConstantOperationOp(op);
1222  })
1223  .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1224  .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1225 
1226  // Then handle all the ops that directly mirror SPIR-V instructions with
1227  // auto-generated methods.
1228  .Default(
1229  [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1230 }
1231 
1232 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1233  StringRef extInstSet,
1234  uint32_t opcode) {
1235  SmallVector<uint32_t, 4> operands;
1236  Location loc = op->getLoc();
1237 
1238  uint32_t resultID = 0;
1239  if (op->getNumResults() != 0) {
1240  uint32_t resultTypeID = 0;
1241  if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1242  return failure();
1243  operands.push_back(resultTypeID);
1244 
1245  resultID = getNextID();
1246  operands.push_back(resultID);
1247  valueIDMap[op->getResult(0)] = resultID;
1248  };
1249 
1250  for (Value operand : op->getOperands())
1251  operands.push_back(getValueID(operand));
1252 
1253  if (failed(emitDebugLine(functionBody, loc)))
1254  return failure();
1255 
1256  if (extInstSet.empty()) {
1257  encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1258  operands);
1259  } else {
1260  if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1261  return failure();
1262  }
1263 
1264  if (op->getNumResults() != 0) {
1265  for (auto attr : op->getAttrs()) {
1266  if (failed(processDecoration(loc, resultID, attr)))
1267  return failure();
1268  }
1269  }
1270 
1271  return success();
1272 }
1273 
1274 LogicalResult Serializer::emitDecoration(uint32_t target,
1275  spirv::Decoration decoration,
1276  ArrayRef<uint32_t> params) {
1277  uint32_t wordCount = 3 + params.size();
1278  llvm::append_values(
1279  decorations,
1280  spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1281  static_cast<uint32_t>(decoration));
1282  llvm::append_range(decorations, params);
1283  return success();
1284 }
1285 
1286 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1287  Location loc) {
1288  if (!options.emitDebugInfo)
1289  return success();
1290 
1291  if (lastProcessedWasMergeInst) {
1292  lastProcessedWasMergeInst = false;
1293  return success();
1294  }
1295 
1296  auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1297  if (fileLoc)
1298  encodeInstructionInto(binary, spirv::Opcode::OpLine,
1299  {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1300  return success();
1301 }
1302 } // namespace spirv
1303 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
Definition: Serializer.cpp:36
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Definition: Serializer.cpp:47
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation & back()
Definition: Block.h:150
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:235
OpListType & getOperations()
Definition: Block.h:135
void print(raw_ostream &os)
bool args_empty()
Definition: Block.h:97
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:220
An attribute that represents a reference to a dense vector or tensor object.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:548
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Definition: Serializer.cpp:139
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
Definition: Serializer.cpp:85
LogicalResult serialize()
Serializes the remembered SPIR-V module.
Definition: Serializer.cpp:89
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
Definition: Serializer.cpp:113
SPIR-V struct type.
Definition: SPIRVTypes.h:293
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:78
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPRI-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Definition: Serializer.cpp:212
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
Definition: Serialization.h:26
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.
Definition: Serialization.h:28