MLIR  19.0.0git
Serializer.cpp
Go to the documentation of this file.
1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/ADT/bit.h"
27 #include "llvm/Support/Debug.h"
28 #include <cstdint>
29 #include <optional>
30 
31 #define DEBUG_TYPE "spirv-serialization"
32 
33 using namespace mlir;
34 
35 /// Returns the merge block if the given `op` is a structured control flow op.
36 /// Otherwise returns nullptr.
/// Only spirv::SelectionOp and spirv::LoopOp are recognized; for either one the
/// result is that op's getMergeBlock(). Any other operation yields nullptr.
38  if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
39  return selectionOp.getMergeBlock();
40  if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
41  return loopOp.getMergeBlock();
// Neither a selection nor a loop: there is no merge block to report.
42  return nullptr;
43 }
44 
45 /// Given a predecessor `block` for a block with arguments, returns the block
46 /// that should be used as the parent block for SPIR-V OpPhi instructions
47 /// corresponding to the block arguments.
48 static Block *getPhiIncomingBlock(Block *block) {
49  // If the predecessor block in question is the entry block for a
50  // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
51  if (block->isEntryBlock()) {
52  if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
53  // Then the incoming parent block for OpPhi should be the merge block of
54  // the structured control flow op before this loop.
55  Operation *op = loopOp.getOperation();
56  while ((op = op->getPrevNode()) != nullptr)
57  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
58  return incomingBlock;
59  // Or the enclosing block itself if no structured control flow ops
60  // exists before this loop.
61  return loopOp->getBlock();
62  }
63  }
64 
65  // Otherwise, we jump from the given predecessor block. Try to see if there is
66  // a structured control flow op inside it.
67  for (Operation &op : llvm::reverse(block->getOperations())) {
68  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
69  return incomingBlock;
70  }
71  return block;
72 }
73 
74 namespace mlir {
75 namespace spirv {
76 
77 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
78 /// the given `binary` vector.
79 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
80  ArrayRef<uint32_t> operands) {
81  uint32_t wordCount = 1 + operands.size();
82  binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
83  binary.append(operands.begin(), operands.end());
84 }
85 
/// Stores the module to serialize and the serialization options; the member
/// `mlirBuilder` is constructed on the module's MLIRContext.
86 Serializer::Serializer(spirv::ModuleOp module,
88  : module(module), mlirBuilder(module.getContext()), options(options) {}
89 
// Top-level driver: verifies the module, emits the module-level sections
// (capabilities, extensions, memory model, debug info), then serializes each
// op of the module body in order. Returns failure if verification fails or if
// any op fails to serialize.
91  LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
92 
93  if (failed(module.verifyInvariants()))
94  return failure();
95 
96  // TODO: handle the other sections
// NOTE(review): these four helpers return void, so none of them can report a
// failure back to this driver.
97  processCapability();
98  processExtension();
99  processMemoryModel();
100  processDebugInfo();
101 
102  // Iterate over the module body to serialize it. Assumptions are that there is
103  // only one basic block in the moduleOp
104  for (auto &op : *module.getBody()) {
105  if (failed(processOperation(&op))) {
106  return failure();
107  }
108  }
109 
110  LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
111  return success();
112 }
113 
// Assembles the final word stream into `binary`. The steps are: clear it,
// reserve capacity, write the module header (version taken from the VCE
// triple; `nextID` is passed too, presumably as the <id> bound), then append
// every section buffer.
115  auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
116  extensions.size() + extendedSets.size() +
117  memoryModel.size() + entryPoints.size() +
118  executionModes.size() + decorations.size() +
119  typesGlobalValues.size() + functions.size();
// NOTE(review): `debug` and `names` are appended below but are not counted in
// moduleSize. reserve() is only a capacity hint, so the output is still
// complete; at worst one extra reallocation happens.
120 
121  binary.clear();
122  binary.reserve(moduleSize);
123 
// NOTE(review): getVceTriple() is dereferenced unchecked; this assumes the
// module always carries a VCE triple.
124  spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
125  nextID);
// The append order below determines the section order in the emitted binary.
126  binary.append(capabilities.begin(), capabilities.end());
127  binary.append(extensions.begin(), extensions.end());
128  binary.append(extendedSets.begin(), extendedSets.end());
129  binary.append(memoryModel.begin(), memoryModel.end());
130  binary.append(entryPoints.begin(), entryPoints.end());
131  binary.append(executionModes.begin(), executionModes.end());
132  binary.append(debug.begin(), debug.end());
133  binary.append(names.begin(), names.end());
134  binary.append(decorations.begin(), decorations.end());
135  binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
136  binary.append(functions.begin(), functions.end());
137 }
138 
139 #ifndef NDEBUG
140 void Serializer::printValueIDMap(raw_ostream &os) {
141  os << "\n= Value <id> Map =\n\n";
142  for (auto valueIDPair : valueIDMap) {
143  Value val = valueIDPair.first;
144  os << " " << val << " "
145  << "id = " << valueIDPair.second << ' ';
146  if (auto *op = val.getDefiningOp()) {
147  os << "from op '" << op->getName() << "'";
148  } else if (auto arg = dyn_cast<BlockArgument>(val)) {
149  Block *block = arg.getOwner();
150  os << "from argument of block " << block << ' ';
151  os << " in op '" << block->getParentOp()->getName() << "'";
152  }
153  os << '\n';
154  }
155 }
156 #endif
157 
158 //===----------------------------------------------------------------------===//
159 // Module structure
160 //===----------------------------------------------------------------------===//
161 
162 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
163  auto funcID = funcIDMap.lookup(fnName);
164  if (!funcID) {
165  funcID = getNextID();
166  funcIDMap[fnName] = funcID;
167  }
168  return funcID;
169 }
170 
171 void Serializer::processCapability() {
172  for (auto cap : module.getVceTriple()->getCapabilities())
173  encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
174  {static_cast<uint32_t>(cap)});
175 }
176 
177 void Serializer::processDebugInfo() {
178  if (!options.emitDebugInfo)
179  return;
180  auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
181  auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
182  fileID = getNextID();
183  SmallVector<uint32_t, 16> operands;
184  operands.push_back(fileID);
185  spirv::encodeStringLiteralInto(operands, fileName);
186  encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
187  // TODO: Encode more debug instructions.
188 }
189 
/// Emits one OpExtension instruction into the extensions section for every
/// extension in the module's VCE triple. The single operand is the
/// extension's name, encoded as a SPIR-V string literal.
190 void Serializer::processExtension() {
192  for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
// `extName` is reused across iterations, so drop the previous name's words.
193  extName.clear();
194  spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
195  encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
196  }
197 }
198 
199 void Serializer::processMemoryModel() {
200  StringAttr memoryModelName = module.getMemoryModelAttrName();
201  auto mm = static_cast<uint32_t>(
202  module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
203  .getValue());
204 
205  StringAttr addressingModelName = module.getAddressingModelAttrName();
206  auto am = static_cast<uint32_t>(
207  module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
208  .getValue());
209 
210  encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
211 }
212 
213 static std::string getDecorationName(StringRef attrName) {
214  // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
215  // expected FPFastMathMode.
216  if (attrName == "fp_fast_math_mode")
217  return "FPFastMathMode";
218 
219  return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
220 }
221 
/// Collects the literal operands for `decoration` from `attr`, then emits the
/// OpDecorate targeting <id> `resultID`. Each supported decoration checks that
/// `attr` has the expected attribute kind and reports an error at `loc`
/// otherwise. Decorations not listed in the switch are rejected as unhandled.
222 LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
223  Decoration decoration,
224  Attribute attr) {
226  switch (decoration) {
227  case spirv::Decoration::LinkageAttributes: {
228  // Get the value of the Linkage Attributes
229  // e.g., LinkageAttributes=["linkageName", linkageType].
// NOTE(review): the dyn_cast result below is dereferenced without a null
// check. An attribute of any other kind would crash here instead of producing
// a diagnostic as the other cases do.
230  auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
231  auto linkageName = linkageAttr.getLinkageName();
232  auto linkageType = linkageAttr.getLinkageType().getValue();
233  // Encode the Linkage Name (string literal to uint32_t).
234  spirv::encodeStringLiteralInto(args, linkageName);
235  // Encode LinkageType & Add the Linkagetype to the args.
236  args.push_back(static_cast<uint32_t>(linkageType));
237  break;
238  }
239  case spirv::Decoration::FPFastMathMode:
240  if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
241  args.push_back(static_cast<uint32_t>(intAttr.getValue()));
242  break;
243  }
244  return emitError(loc, "expected FPFastMathModeAttr attribute for ")
245  << stringifyDecoration(decoration);
246  case spirv::Decoration::Binding:
247  case spirv::Decoration::DescriptorSet:
248  case spirv::Decoration::Location:
// getZExtValue() yields a uint64_t that is implicitly narrowed to one 32-bit
// literal word when pushed into `args`.
249  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
250  args.push_back(intAttr.getValue().getZExtValue());
251  break;
252  }
253  return emitError(loc, "expected integer attribute for ")
254  << stringifyDecoration(decoration);
255  case spirv::Decoration::BuiltIn:
// The built-in is given by name and mapped to its enum value.
256  if (auto strAttr = dyn_cast<StringAttr>(attr)) {
257  auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
258  if (enumVal) {
259  args.push_back(static_cast<uint32_t>(*enumVal));
260  break;
261  }
262  return emitError(loc, "invalid ")
263  << stringifyDecoration(decoration) << " decoration attribute "
264  << strAttr.getValue();
265  }
266  return emitError(loc, "expected string attribute for ")
267  << stringifyDecoration(decoration);
268  case spirv::Decoration::Aliased:
269  case spirv::Decoration::AliasedPointer:
270  case spirv::Decoration::Flat:
271  case spirv::Decoration::NonReadable:
272  case spirv::Decoration::NonWritable:
273  case spirv::Decoration::NoPerspective:
274  case spirv::Decoration::NoSignedWrap:
275  case spirv::Decoration::NoUnsignedWrap:
276  case spirv::Decoration::RelaxedPrecision:
277  case spirv::Decoration::Restrict:
278  case spirv::Decoration::RestrictPointer:
279  case spirv::Decoration::NoContraction:
280  // For unit attributes and decoration attributes, the args list
281  // has no values so we do nothing.
282  if (isa<UnitAttr, DecorationAttr>(attr))
283  break;
284  return emitError(loc,
285  "expected unit attribute or decoration attribute for ")
286  << stringifyDecoration(decoration);
287  default:
288  return emitError(loc, "unhandled decoration ")
289  << stringifyDecoration(decoration);
290  }
// Emit the decoration with whatever literal operands were collected; this may
// be none.
291  return emitDecoration(resultID, decoration, args);
292 }
293 
294 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
295  NamedAttribute attr) {
296  StringRef attrName = attr.getName().strref();
297  std::string decorationName = getDecorationName(attrName);
298  std::optional<Decoration> decoration =
299  spirv::symbolizeDecoration(decorationName);
300  if (!decoration) {
301  return emitError(
302  loc, "non-argument attributes expected to have snake-case-ified "
303  "decoration name, unhandled attribute with name : ")
304  << attrName;
305  }
306  return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
307 }
308 
309 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
310  assert(!name.empty() && "unexpected empty string for OpName");
311  if (!options.emitSymbolName)
312  return success();
313 
314  SmallVector<uint32_t, 4> nameOperands;
315  nameOperands.push_back(resultID);
316  spirv::encodeStringLiteralInto(nameOperands, name);
317  encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
318  return success();
319 }
320 
321 template <>
322 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
323  Location loc, spirv::ArrayType type, uint32_t resultID) {
324  if (unsigned stride = type.getArrayStride()) {
325  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
326  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
327  }
328  return success();
329 }
330 
331 template <>
332 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
333  Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
334  if (unsigned stride = type.getArrayStride()) {
335  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
336  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
337  }
338  return success();
339 }
340 
/// Emits an OpMemberDecorate into the decorations section for member
/// `memberDecoration.memberIndex` of the struct type with <id> `structID`.
/// The decoration's literal value is appended only when
/// `memberDecoration.hasValue` is set. Always returns success.
341 LogicalResult Serializer::processMemberDecoration(
342  uint32_t structID,
343  const spirv::StructType::MemberDecorationInfo &memberDecoration) {
345  {structID, memberDecoration.memberIndex,
346  static_cast<uint32_t>(memberDecoration.decoration)});
347  if (memberDecoration.hasValue) {
348  args.push_back(memberDecoration.decorationValue);
349  }
350  encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
351  return success();
352 }
353 
354 //===----------------------------------------------------------------------===//
355 // Type
356 //===----------------------------------------------------------------------===//
357 
358 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
359 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
360 // PushConstant Storage Classes must be explicitly laid out."
361 bool Serializer::isInterfaceStructPtrType(Type type) const {
362  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
363  switch (ptrType.getStorageClass()) {
364  case spirv::StorageClass::PhysicalStorageBuffer:
365  case spirv::StorageClass::PushConstant:
366  case spirv::StorageClass::StorageBuffer:
367  case spirv::StorageClass::Uniform:
368  return isa<spirv::StructType>(ptrType.getPointeeType());
369  default:
370  break;
371  }
372  }
373  return false;
374 }
375 
376 LogicalResult Serializer::processType(Location loc, Type type,
377  uint32_t &typeID) {
378  // Maintains a set of names for nested identified struct types. This is used
379  // to properly serialize recursive references.
380  SetVector<StringRef> serializationCtx;
381  return processTypeImpl(loc, type, typeID, serializationCtx);
382 }
383 
// Serializes `type` into the types/global-values section, caches its <id> in
// typeIDMap, and returns the <id> through `typeID`. `serializationCtx` holds
// the names of identified structs that are currently being serialized; it is
// passed down to prepareBasicType.
385 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
386  SetVector<StringRef> &serializationCtx) {
// Cache hit: this type was serialized before.
387  typeID = getTypeID(type);
388  if (typeID)
389  return success();
390 
// The <id> is reserved before preparing operands because prepareBasicType
// receives it as `resultID`. If preparation fails, the <id> is simply left
// unused.
391  typeID = getNextID();
392  SmallVector<uint32_t, 4> operands;
393 
394  operands.push_back(typeID);
395  auto typeEnum = spirv::Opcode::OpTypeVoid;
396  bool deferSerialization = false;
397 
// FunctionType is prepared by prepareFunctionType. If that fails,
// prepareBasicType is still tried; it presumably matches no case for a
// FunctionType and reports an unhandled-type error.
398  if ((isa<FunctionType>(type) &&
399  succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
400  operands))) ||
401  succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
402  deferSerialization, serializationCtx))) {
// A deferred pointer type is neither cached nor emitted here. Its OpTypePointer
// is emitted below, once the recursive struct it points to is complete.
403  if (deferSerialization)
404  return success();
405 
406  typeIDMap[type] = typeID;
407 
408  encodeInstructionInto(typesGlobalValues, typeEnum, operands);
409 
410  if (recursiveStructInfos.count(type) != 0) {
411  // This recursive struct type is emitted already, now the OpTypePointer
412  // instructions referring to recursive references are emitted as well.
413  for (auto &ptrInfo : recursiveStructInfos[type]) {
414  // TODO: This might not work if more than 1 recursive reference is
415  // present in the struct.
416  SmallVector<uint32_t, 4> ptrOperands;
417  ptrOperands.push_back(ptrInfo.pointerTypeID);
418  ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
419  ptrOperands.push_back(typeIDMap[type]);
420 
421  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
422  ptrOperands);
423  }
424 
425  recursiveStructInfos[type].clear();
426  }
427 
428  return success();
429  }
430 
431  return failure();
432 }
433 
// Selects the OpType* opcode for `type` (returned via `typeEnum`) and appends
// its operands after the result <id> already in `operands`. Nested element
// types are serialized first. `resultID` is the <id> reserved for `type`; it
// is used for OpName, member decorations and forward pointers.
// `deferSerialization` is set when a pointer to an enclosing identified struct
// is found; the caller then postpones emitting it. Unsupported types produce
// an "unhandled type" error.
434 LogicalResult Serializer::prepareBasicType(
435  Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
436  SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
437  SetVector<StringRef> &serializationCtx) {
438  deferSerialization = false;
439 
440  if (isVoidType(type)) {
441  typeEnum = spirv::Opcode::OpTypeVoid;
442  return success();
443  }
444 
// i1 maps to OpTypeBool; every other integer width maps to OpTypeInt.
445  if (auto intType = dyn_cast<IntegerType>(type)) {
446  if (intType.getWidth() == 1) {
447  typeEnum = spirv::Opcode::OpTypeBool;
448  return success();
449  }
450 
451  typeEnum = spirv::Opcode::OpTypeInt;
452  operands.push_back(intType.getWidth());
453  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
454  // to preserve or validate.
455  // 0 indicates unsigned, or no signedness semantics
456  // 1 indicates signed semantics."
457  operands.push_back(intType.isSigned() ? 1 : 0);
458  return success();
459  }
460 
461  if (auto floatType = dyn_cast<FloatType>(type)) {
462  typeEnum = spirv::Opcode::OpTypeFloat;
463  operands.push_back(floatType.getWidth());
464  return success();
465  }
466 
467  if (auto vectorType = dyn_cast<VectorType>(type)) {
468  uint32_t elementTypeID = 0;
469  if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
470  serializationCtx))) {
471  return failure();
472  }
473  typeEnum = spirv::Opcode::OpTypeVector;
474  operands.push_back(elementTypeID);
475  operands.push_back(vectorType.getNumElements());
476  return success();
477  }
478 
// Unlike the other nested cases, the sampled type is serialized via
// processType, i.e. with a fresh, empty serialization context.
479  if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
480  typeEnum = spirv::Opcode::OpTypeImage;
481  uint32_t sampledTypeID = 0;
482  if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
483  return failure();
484 
485  llvm::append_values(operands, sampledTypeID,
486  static_cast<uint32_t>(imageType.getDim()),
487  static_cast<uint32_t>(imageType.getDepthInfo()),
488  static_cast<uint32_t>(imageType.getArrayedInfo()),
489  static_cast<uint32_t>(imageType.getSamplingInfo()),
490  static_cast<uint32_t>(imageType.getSamplerUseInfo()),
491  static_cast<uint32_t>(imageType.getImageFormat()));
492  return success();
493  }
494 
495  if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
496  typeEnum = spirv::Opcode::OpTypeArray;
497  uint32_t elementTypeID = 0;
498  if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
499  serializationCtx))) {
500  return failure();
501  }
502  operands.push_back(elementTypeID);
// The element count is an <id> of an i32 constant. NOTE(review): if
// prepareConstantInt returns 0, the count operand is silently omitted rather
// than treated as a failure.
503  if (auto elementCountID = prepareConstantInt(
504  loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
505  operands.push_back(elementCountID);
506  }
507  return processTypeDecoration(loc, arrayType, resultID);
508  }
509 
510  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
511  uint32_t pointeeTypeID = 0;
512  spirv::StructType pointeeStruct =
513  dyn_cast<spirv::StructType>(ptrType.getPointeeType());
514 
515  if (pointeeStruct && pointeeStruct.isIdentified() &&
516  serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
517  // A recursive reference to an enclosing struct is found.
518  //
519  // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
520  // class as operands.
521  SmallVector<uint32_t, 2> forwardPtrOperands;
522  forwardPtrOperands.push_back(resultID);
523  forwardPtrOperands.push_back(
524  static_cast<uint32_t>(ptrType.getStorageClass()));
525 
526  encodeInstructionInto(typesGlobalValues,
527  spirv::Opcode::OpTypeForwardPointer,
528  forwardPtrOperands);
529 
530  // 2. Find the pointee (enclosing) struct.
531  auto structType = spirv::StructType::getIdentified(
532  module.getContext(), pointeeStruct.getIdentifier());
533 
534  if (!structType)
535  return failure();
536 
537  // 3. Mark the OpTypePointer that is supposed to be emitted by this call
538  // as deferred.
539  deferSerialization = true;
540 
541  // 4. Record the info needed to emit the deferred OpTypePointer
542  // instruction when the enclosing struct is completely serialized.
543  recursiveStructInfos[structType].push_back(
544  {resultID, ptrType.getStorageClass()});
545  } else {
546  if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
547  serializationCtx)))
548  return failure();
549  }
550 
// In the deferred (recursive) path pointeeTypeID is still 0 here. Presumably
// harmless, because the caller discards `operands` when deferSerialization is
// set; TODO confirm.
551  typeEnum = spirv::Opcode::OpTypePointer;
552  operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
553  operands.push_back(pointeeTypeID);
554 
// NOTE(review): Block is emitted each time a new interface pointer type to the
// same struct is serialized (e.g. another storage class); confirm whether
// duplicate decorations are acceptable.
555  if (isInterfaceStructPtrType(ptrType)) {
556  if (failed(emitDecoration(getTypeID(pointeeStruct),
557  spirv::Decoration::Block)))
558  return emitError(loc, "cannot decorate ")
559  << pointeeStruct << " with Block decoration";
560  }
561 
562  return success();
563  }
564 
565  if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
566  uint32_t elementTypeID = 0;
567  if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
568  elementTypeID, serializationCtx))) {
569  return failure();
570  }
571  typeEnum = spirv::Opcode::OpTypeRuntimeArray;
572  operands.push_back(elementTypeID);
573  return processTypeDecoration(loc, runtimeArrayType, resultID);
574  }
575 
576  if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
577  typeEnum = spirv::Opcode::OpTypeSampledImage;
578  uint32_t imageTypeID = 0;
579  if (failed(
580  processType(loc, sampledImageType.getImageType(), imageTypeID))) {
581  return failure();
582  }
583  operands.push_back(imageTypeID);
584  return success();
585  }
586 
// An identified struct is named via OpName and pushed onto serializationCtx
// while its members are processed. This lets pointer members that refer back
// to it be detected as recursive.
587  if (auto structType = dyn_cast<spirv::StructType>(type)) {
588  if (structType.isIdentified()) {
589  if (failed(processName(resultID, structType.getIdentifier())))
590  return failure();
591  serializationCtx.insert(structType.getIdentifier());
592  }
593 
594  bool hasOffset = structType.hasOffset();
595  for (auto elementIndex :
596  llvm::seq<uint32_t>(0, structType.getNumElements())) {
597  uint32_t elementTypeID = 0;
598  if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
599  elementTypeID, serializationCtx))) {
600  return failure();
601  }
602  operands.push_back(elementTypeID);
603  if (hasOffset) {
604  // Decorate each struct member with an offset
606  elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
607  static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
608  if (failed(processMemberDecoration(resultID, offsetDecoration))) {
609  return emitError(loc, "cannot decorate ")
610  << elementIndex << "-th member of " << structType
611  << " with its offset";
612  }
613  }
614  }
616  structType.getMemberDecorations(memberDecorations);
617 
618  for (auto &memberDecoration : memberDecorations) {
619  if (failed(processMemberDecoration(resultID, memberDecoration))) {
620  return emitError(loc, "cannot decorate ")
621  << static_cast<uint32_t>(memberDecoration.memberIndex)
622  << "-th member of " << structType << " with "
623  << stringifyDecoration(memberDecoration.decoration);
624  }
625  }
626 
627  typeEnum = spirv::Opcode::OpTypeStruct;
628 
// NOTE(review): the early failure returns above skip this removal, so a
// failed struct stays in serializationCtx for the rest of that traversal.
629  if (structType.isIdentified())
630  serializationCtx.remove(structType.getIdentifier());
631 
632  return success();
633  }
634 
// Scope, rows, columns and use are passed as <id>s of i32 constants, not as
// literals.
635  if (auto cooperativeMatrixType =
636  dyn_cast<spirv::CooperativeMatrixType>(type)) {
637  uint32_t elementTypeID = 0;
638  if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
639  elementTypeID, serializationCtx))) {
640  return failure();
641  }
642  typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
643  auto getConstantOp = [&](uint32_t id) {
644  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
645  return prepareConstantInt(loc, attr);
646  };
647  llvm::append_values(
648  operands, elementTypeID,
649  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
650  getConstantOp(cooperativeMatrixType.getRows()),
651  getConstantOp(cooperativeMatrixType.getColumns()),
652  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
653  return success();
654  }
655 
656  if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
657  uint32_t elementTypeID = 0;
658  if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
659  elementTypeID, serializationCtx))) {
660  return failure();
661  }
662  typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
663  auto getConstantOp = [&](uint32_t id) {
664  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
665  return prepareConstantInt(loc, attr);
666  };
667  llvm::append_values(
668  operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
669  getConstantOp(jointMatrixType.getColumns()),
670  getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout())),
671  getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
672  return success();
673  }
674 
675  if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
676  uint32_t elementTypeID = 0;
677  if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
678  serializationCtx))) {
679  return failure();
680  }
681  typeEnum = spirv::Opcode::OpTypeMatrix;
682  llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
683  return success();
684  }
685 
686  // TODO: Handle other types.
687  return emitError(loc, "unhandled type in serialization: ") << type;
688 }
689 
// Prepares the operands of an OpTypeFunction for `type`. The first operand is
// the return type's <id> (void when there are no results), followed by one
// <id> per input type. At most one result is supported, and this is only
// checked by an assert.
691 Serializer::prepareFunctionType(Location loc, FunctionType type,
692  spirv::Opcode &typeEnum,
693  SmallVectorImpl<uint32_t> &operands) {
694  typeEnum = spirv::Opcode::OpTypeFunction;
695  assert(type.getNumResults() <= 1 &&
696  "serialization supports only a single return value");
697  uint32_t resultID = 0;
698  if (failed(processType(
699  loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
700  resultID))) {
701  return failure();
702  }
703  operands.push_back(resultID);
// Despite the name `res`, this loop iterates over the parameter (input)
// types.
704  for (auto &res : type.getInputs()) {
705  uint32_t argTypeID = 0;
706  if (failed(processType(loc, res, argTypeID))) {
707  return failure();
708  }
709  operands.push_back(argTypeID);
710  }
711  return success();
712 }
713 
714 //===----------------------------------------------------------------------===//
715 // Constant
716 //===----------------------------------------------------------------------===//
717 
718 uint32_t Serializer::prepareConstant(Location loc, Type constType,
719  Attribute valueAttr) {
720  if (auto id = prepareConstantScalar(loc, valueAttr)) {
721  return id;
722  }
723 
724  // This is a composite literal. We need to handle each component separately
725  // and then emit an OpConstantComposite for the whole.
726 
727  if (auto id = getConstantID(valueAttr)) {
728  return id;
729  }
730 
731  uint32_t typeID = 0;
732  if (failed(processType(loc, constType, typeID))) {
733  return 0;
734  }
735 
736  uint32_t resultID = 0;
737  if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
738  int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
739  SmallVector<uint64_t, 4> index(rank);
740  resultID = prepareDenseElementsConstant(loc, constType, attr,
741  /*dim=*/0, index);
742  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
743  resultID = prepareArrayConstant(loc, constType, arrayAttr);
744  }
745 
746  if (resultID == 0) {
747  emitError(loc, "cannot serialize attribute: ") << valueAttr;
748  return 0;
749  }
750 
751  constIDMap[valueAttr] = resultID;
752  return resultID;
753 }
754 
755 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
756  ArrayAttr attr) {
757  uint32_t typeID = 0;
758  if (failed(processType(loc, constType, typeID))) {
759  return 0;
760  }
761 
762  uint32_t resultID = getNextID();
763  SmallVector<uint32_t, 4> operands = {typeID, resultID};
764  operands.reserve(attr.size() + 2);
765  auto elementType = cast<spirv::ArrayType>(constType).getElementType();
766  for (Attribute elementAttr : attr) {
767  if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
768  operands.push_back(elementID);
769  } else {
770  return 0;
771  }
772  }
773  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
774  encodeInstructionInto(typesGlobalValues, opcode, operands);
775 
776  return resultID;
777 }
778 
779 // TODO: Turn the below function into iterative function, instead of
780 // recursive function.
// Serializes the part of `valueAttr` selected by the first `dim` entries of
// `index`. When dim == rank, returns the <id> of the single scalar element at
// `index` (bool, integer or float; 0 for any other element kind). Otherwise it
// recurses over dimension `dim`, filling index[dim], and emits an
// OpConstantComposite of `constType`. The child type is constType's element 0.
// Returns 0 on any failure.
781 uint32_t
782 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
783  DenseElementsAttr valueAttr, int dim,
// NOTE(review): this dyn_cast result is used without a null check; cast<>
// would state the invariant that the attribute's type is shaped.
785  auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
786  assert(dim <= shapedType.getRank());
787  if (shapedType.getRank() == dim) {
// Leaf: `index` now addresses exactly one element. i1 elements become bool
// constants.
788  if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
789  return attr.getType().getElementType().isInteger(1)
790  ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
791  : prepareConstantInt(loc,
792  attr.getValues<IntegerAttr>()[index]);
793  }
794  if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
795  return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
796  }
797  return 0;
798  }
799 
800  uint32_t typeID = 0;
801  if (failed(processType(loc, constType, typeID))) {
802  return 0;
803  }
804 
805  uint32_t resultID = getNextID();
806  SmallVector<uint32_t, 4> operands = {typeID, resultID};
807  operands.reserve(shapedType.getDimSize(dim) + 2);
808  auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
809  for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
810  index[dim] = i;
811  if (auto elementID = prepareDenseElementsConstant(
812  loc, elementType, valueAttr, dim + 1, index)) {
813  operands.push_back(elementID);
814  } else {
815  return 0;
816  }
817  }
818  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
819  encodeInstructionInto(typesGlobalValues, opcode, operands);
820 
821  return resultID;
822 }
823 
824 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
825  bool isSpec) {
826  if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
827  return prepareConstantFp(loc, floatAttr, isSpec);
828  }
829  if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
830  return prepareConstantBool(loc, boolAttr, isSpec);
831  }
832  if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
833  return prepareConstantInt(loc, intAttr, isSpec);
834  }
835 
836  return 0;
837 }
838 
839 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
840  bool isSpec) {
841  if (!isSpec) {
842  // We can de-duplicate normal constants, but not specialization constants.
843  if (auto id = getConstantID(boolAttr)) {
844  return id;
845  }
846  }
847 
848  // Process the type for this bool literal
849  uint32_t typeID = 0;
850  if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
851  return 0;
852  }
853 
854  auto resultID = getNextID();
855  auto opcode = boolAttr.getValue()
856  ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
857  : spirv::Opcode::OpConstantTrue)
858  : (isSpec ? spirv::Opcode::OpSpecConstantFalse
859  : spirv::Opcode::OpConstantFalse);
860  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
861 
862  if (!isSpec) {
863  constIDMap[boolAttr] = resultID;
864  }
865  return resultID;
866 }
867 
868 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
869  bool isSpec) {
870  if (!isSpec) {
871  // We can de-duplicate normal constants, but not specialization constants.
872  if (auto id = getConstantID(intAttr)) {
873  return id;
874  }
875  }
876 
877  // Process the type for this integer literal
878  uint32_t typeID = 0;
879  if (failed(processType(loc, intAttr.getType(), typeID))) {
880  return 0;
881  }
882 
883  auto resultID = getNextID();
884  APInt value = intAttr.getValue();
885  unsigned bitwidth = value.getBitWidth();
886  bool isSigned = intAttr.getType().isSignedInteger();
887  auto opcode =
888  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
889 
890  switch (bitwidth) {
891  // According to SPIR-V spec, "When the type's bit width is less than
892  // 32-bits, the literal's value appears in the low-order bits of the word,
893  // and the high-order bits must be 0 for a floating-point type, or 0 for an
894  // integer type with Signedness of 0, or sign extended when Signedness
895  // is 1."
896  case 32:
897  case 16:
898  case 8: {
899  uint32_t word = 0;
900  if (isSigned) {
901  word = static_cast<int32_t>(value.getSExtValue());
902  } else {
903  word = static_cast<uint32_t>(value.getZExtValue());
904  }
905  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
906  } break;
907  // According to SPIR-V spec: "When the type's bit width is larger than one
908  // word, the literal’s low-order words appear first."
909  case 64: {
910  struct DoubleWord {
911  uint32_t word1;
912  uint32_t word2;
913  } words;
914  if (isSigned) {
915  words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
916  } else {
917  words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
918  }
919  encodeInstructionInto(typesGlobalValues, opcode,
920  {typeID, resultID, words.word1, words.word2});
921  } break;
922  default: {
923  std::string valueStr;
924  llvm::raw_string_ostream rss(valueStr);
925  value.print(rss, /*isSigned=*/false);
926 
927  emitError(loc, "cannot serialize ")
928  << bitwidth << "-bit integer literal: " << rss.str();
929  return 0;
930  }
931  }
932 
933  if (!isSpec) {
934  constIDMap[intAttr] = resultID;
935  }
936  return resultID;
937 }
938 
939 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
940  bool isSpec) {
941  if (!isSpec) {
942  // We can de-duplicate normal constants, but not specialization constants.
943  if (auto id = getConstantID(floatAttr)) {
944  return id;
945  }
946  }
947 
948  // Process the type for this float literal
949  uint32_t typeID = 0;
950  if (failed(processType(loc, floatAttr.getType(), typeID))) {
951  return 0;
952  }
953 
954  auto resultID = getNextID();
955  APFloat value = floatAttr.getValue();
956  APInt intValue = value.bitcastToAPInt();
957 
958  auto opcode =
959  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
960 
961  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
962  uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
963  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
964  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
965  struct DoubleWord {
966  uint32_t word1;
967  uint32_t word2;
968  } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
969  encodeInstructionInto(typesGlobalValues, opcode,
970  {typeID, resultID, words.word1, words.word2});
971  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
972  uint32_t word =
973  static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
974  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
975  } else {
976  std::string valueStr;
977  llvm::raw_string_ostream rss(valueStr);
978  value.print(rss);
979 
980  emitError(loc, "cannot serialize ")
981  << floatAttr.getType() << "-typed float literal: " << rss.str();
982  return 0;
983  }
984 
985  if (!isSpec) {
986  constIDMap[floatAttr] = resultID;
987  }
988  return resultID;
989 }
990 
991 //===----------------------------------------------------------------------===//
992 // Control flow
993 //===----------------------------------------------------------------------===//
994 
995 uint32_t Serializer::getOrCreateBlockID(Block *block) {
996  if (uint32_t id = getBlockID(block))
997  return id;
998  return blockIDMap[block] = getNextID();
999 }
1000 
1001 #ifndef NDEBUG
1002 void Serializer::printBlock(Block *block, raw_ostream &os) {
1003  os << "block " << block << " (id = ";
1004  if (uint32_t id = getBlockID(block))
1005  os << id;
1006  else
1007  os << "unknown";
1008  os << ")\n";
1009 }
1010 #endif
1011 
1013 Serializer::processBlock(Block *block, bool omitLabel,
1014  function_ref<LogicalResult()> emitMerge) {
1015  LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1016  LLVM_DEBUG(block->print(llvm::dbgs()));
1017  LLVM_DEBUG(llvm::dbgs() << '\n');
1018  if (!omitLabel) {
1019  uint32_t blockID = getOrCreateBlockID(block);
1020  LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1021 
1022  // Emit OpLabel for this block.
1023  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1024  }
1025 
1026  // Emit OpPhi instructions for block arguments, if any.
1027  if (failed(emitPhiForBlockArguments(block)))
1028  return failure();
1029 
1030  // If we need to emit merge instructions, it must happen in this block. Check
1031  // whether we have other structured control flow ops, which will be expanded
1032  // into multiple basic blocks. If that's the case, we need to emit the merge
1033  // right now and then create new blocks for further serialization of the ops
1034  // in this block.
1035  if (emitMerge &&
1036  llvm::any_of(block->getOperations(),
1037  llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1038  if (failed(emitMerge()))
1039  return failure();
1040  emitMerge = nullptr;
1041 
1042  // Start a new block for further serialization.
1043  uint32_t blockID = getNextID();
1044  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1045  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1046  }
1047 
1048  // Process each op in this block except the terminator.
1049  for (Operation &op : llvm::drop_end(*block)) {
1050  if (failed(processOperation(&op)))
1051  return failure();
1052  }
1053 
1054  // Process the terminator.
1055  if (emitMerge)
1056  if (failed(emitMerge()))
1057  return failure();
1058  if (failed(processOperation(&block->back())))
1059  return failure();
1060 
1061  return success();
1062 }
1063 
1064 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1065  // Nothing to do if this block has no arguments or it's the entry block, which
1066  // always has the same arguments as the function signature.
1067  if (block->args_empty() || block->isEntryBlock())
1068  return success();
1069 
1070  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1071 
1072  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1073  // A SPIR-V OpPhi instruction is of the syntax:
1074  // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1075  // So we need to collect all predecessor blocks and the arguments they send
1076  // to this block.
1078  for (Block *mlirPredecessor : block->getPredecessors()) {
1079  auto *terminator = mlirPredecessor->getTerminator();
1080  LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1081  LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1082  LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1083  // The predecessor here is the immediate one according to MLIR's IR
1084  // structure. It does not directly map to the incoming parent block for the
1085  // OpPhi instructions at SPIR-V binary level. This is because structured
1086  // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1087  // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1088  // the branch op jumping to the OpPhi's block then resides in the previous
1089  // structured control flow op's merge block.
1090  Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1091  LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1092  LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1093  if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1094  predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1095  } else if (auto branchCondOp =
1096  dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1097  std::optional<OperandRange> blockOperands;
1098  if (branchCondOp.getTrueTarget() == block) {
1099  blockOperands = branchCondOp.getTrueTargetOperands();
1100  } else {
1101  assert(branchCondOp.getFalseTarget() == block);
1102  blockOperands = branchCondOp.getFalseTargetOperands();
1103  }
1104 
1105  assert(!blockOperands->empty() &&
1106  "expected non-empty block operand range");
1107  predecessors.emplace_back(spirvPredecessor, *blockOperands);
1108  } else {
1109  return terminator->emitError("unimplemented terminator for Phi creation");
1110  }
1111  LLVM_DEBUG({
1112  llvm::dbgs() << " block arguments:\n";
1113  for (Value v : predecessors.back().second)
1114  llvm::dbgs() << " " << v << "\n";
1115  });
1116  }
1117 
1118  // Then create OpPhi instruction for each of the block argument.
1119  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1120  BlockArgument arg = block->getArgument(argIndex);
1121 
1122  // Get the type <id> and result <id> for this OpPhi instruction.
1123  uint32_t phiTypeID = 0;
1124  if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1125  return failure();
1126  uint32_t phiID = getNextID();
1127 
1128  LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1129  << arg << " (id = " << phiID << ")\n");
1130 
1131  // Prepare the (value <id>, parent block <id>) pairs.
1132  SmallVector<uint32_t, 8> phiArgs;
1133  phiArgs.push_back(phiTypeID);
1134  phiArgs.push_back(phiID);
1135 
1136  for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1137  Value value = predecessors[predIndex].second[argIndex];
1138  uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1139  LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1140  << ") value " << value << ' ');
1141  // Each pair is a value <id> ...
1142  uint32_t valueId = getValueID(value);
1143  if (valueId == 0) {
1144  // The op generating this value hasn't been visited yet so we don't have
1145  // an <id> assigned yet. Record this to fix up later.
1146  LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1147  deferredPhiValues[value].push_back(functionBody.size() + 1 +
1148  phiArgs.size());
1149  } else {
1150  LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1151  }
1152  phiArgs.push_back(valueId);
1153  // ... and a parent block <id>.
1154  phiArgs.push_back(predBlockId);
1155  }
1156 
1157  encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1158  valueIDMap[arg] = phiID;
1159  }
1160 
1161  return success();
1162 }
1163 
1164 //===----------------------------------------------------------------------===//
1165 // Operation
1166 //===----------------------------------------------------------------------===//
1167 
1168 LogicalResult Serializer::encodeExtensionInstruction(
1169  Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1170  ArrayRef<uint32_t> operands) {
1171  // Check if the extension has been imported.
1172  auto &setID = extendedInstSetIDMap[extensionSetName];
1173  if (!setID) {
1174  setID = getNextID();
1175  SmallVector<uint32_t, 16> importOperands;
1176  importOperands.push_back(setID);
1177  spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1178  encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1179  importOperands);
1180  }
1181 
1182  // The first two operands are the result type <id> and result <id>. The set
1183  // <id> and the opcode need to be insert after this.
1184  if (operands.size() < 2) {
1185  return op->emitError("extended instructions must have a result encoding");
1186  }
1187  SmallVector<uint32_t, 8> extInstOperands;
1188  extInstOperands.reserve(operands.size() + 2);
1189  extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1190  extInstOperands.push_back(setID);
1191  extInstOperands.push_back(extensionOpcode);
1192  extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1193  encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1194  extInstOperands);
1195  return success();
1196 }
1197 
1198 LogicalResult Serializer::processOperation(Operation *opInst) {
1199  LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1200 
1201  // First dispatch the ops that do not directly mirror an instruction from
1202  // the SPIR-V spec.
1204  .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1205  .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1206  .Case([&](spirv::BranchConditionalOp op) {
1207  return processBranchConditionalOp(op);
1208  })
1209  .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1210  .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1211  .Case([&](spirv::GlobalVariableOp op) {
1212  return processGlobalVariableOp(op);
1213  })
1214  .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1215  .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1216  .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1217  .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1218  .Case([&](spirv::SpecConstantCompositeOp op) {
1219  return processSpecConstantCompositeOp(op);
1220  })
1221  .Case([&](spirv::SpecConstantOperationOp op) {
1222  return processSpecConstantOperationOp(op);
1223  })
1224  .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1225  .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1226 
1227  // Then handle all the ops that directly mirror SPIR-V instructions with
1228  // auto-generated methods.
1229  .Default(
1230  [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1231 }
1232 
1233 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1234  StringRef extInstSet,
1235  uint32_t opcode) {
1236  SmallVector<uint32_t, 4> operands;
1237  Location loc = op->getLoc();
1238 
1239  uint32_t resultID = 0;
1240  if (op->getNumResults() != 0) {
1241  uint32_t resultTypeID = 0;
1242  if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1243  return failure();
1244  operands.push_back(resultTypeID);
1245 
1246  resultID = getNextID();
1247  operands.push_back(resultID);
1248  valueIDMap[op->getResult(0)] = resultID;
1249  };
1250 
1251  for (Value operand : op->getOperands())
1252  operands.push_back(getValueID(operand));
1253 
1254  if (failed(emitDebugLine(functionBody, loc)))
1255  return failure();
1256 
1257  if (extInstSet.empty()) {
1258  encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1259  operands);
1260  } else {
1261  if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1262  return failure();
1263  }
1264 
1265  if (op->getNumResults() != 0) {
1266  for (auto attr : op->getAttrs()) {
1267  if (failed(processDecoration(loc, resultID, attr)))
1268  return failure();
1269  }
1270  }
1271 
1272  return success();
1273 }
1274 
1275 LogicalResult Serializer::emitDecoration(uint32_t target,
1276  spirv::Decoration decoration,
1277  ArrayRef<uint32_t> params) {
1278  uint32_t wordCount = 3 + params.size();
1279  llvm::append_values(
1280  decorations,
1281  spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1282  static_cast<uint32_t>(decoration));
1283  llvm::append_range(decorations, params);
1284  return success();
1285 }
1286 
1287 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1288  Location loc) {
1289  if (!options.emitDebugInfo)
1290  return success();
1291 
1292  if (lastProcessedWasMergeInst) {
1293  lastProcessedWasMergeInst = false;
1294  return success();
1295  }
1296 
1297  auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1298  if (fileLoc)
1299  encodeInstructionInto(binary, spirv::Opcode::OpLine,
1300  {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1301  return success();
1302 }
1303 } // namespace spirv
1304 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
Definition: Serializer.cpp:37
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Definition: Serializer.cpp:48
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
unsigned getNumArguments()
Definition: Block.h:125
Operation & back()
Definition: Block.h:149
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:234
OpListType & getOperations()
Definition: Block.h:134
void print(raw_ostream &os)
bool args_empty()
Definition: Block.h:96
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
An attribute that represents a reference to a dense vector or tensor object.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:202
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:216
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:548
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Definition: Serializer.cpp:140
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
Definition: Serializer.cpp:86
LogicalResult serialize()
Serializes the remembered SPIR-V module.
Definition: Serializer.cpp:90
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
Definition: Serializer.cpp:114
SPIR-V struct type.
Definition: SPIRVTypes.h:293
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:79
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Definition: Serializer.cpp:213
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
Definition: Serialization.h:27
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.
Definition: Serialization.h:29