MLIR  21.0.0git
Serializer.cpp
Go to the documentation of this file.
1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
27 #include <cstdint>
28 #include <optional>
29 
30 #define DEBUG_TYPE "spirv-serialization"
31 
32 using namespace mlir;
33 
34 /// Returns the merge block if the given `op` is a structured control flow op.
35 /// Otherwise returns nullptr.
37  if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38  return selectionOp.getMergeBlock();
39  if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40  return loopOp.getMergeBlock();
41  return nullptr;
42 }
43 
44 /// Given a predecessor `block` for a block with arguments, returns the block
45 /// that should be used as the parent block for SPIR-V OpPhi instructions
46 /// corresponding to the block arguments.
47 static Block *getPhiIncomingBlock(Block *block) {
48  // If the predecessor block in question is the entry block for a
49  // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50  if (block->isEntryBlock()) {
51  if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52  // Then the incoming parent block for OpPhi should be the merge block of
53  // the structured control flow op before this loop.
54  Operation *op = loopOp.getOperation();
55  while ((op = op->getPrevNode()) != nullptr)
56  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57  return incomingBlock;
58  // Or the enclosing block itself if no structured control flow ops
59  // exists before this loop.
60  return loopOp->getBlock();
61  }
62  }
63 
64  // Otherwise, we jump from the given predecessor block. Try to see if there is
65  // a structured control flow op inside it.
66  for (Operation &op : llvm::reverse(block->getOperations())) {
67  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68  return incomingBlock;
69  }
70  return block;
71 }
72 
73 namespace mlir {
74 namespace spirv {
75 
76 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
77 /// the given `binary` vector.
78 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
79  ArrayRef<uint32_t> operands) {
80  uint32_t wordCount = 1 + operands.size();
81  binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
82  binary.append(operands.begin(), operands.end());
83 }
84 
85 Serializer::Serializer(spirv::ModuleOp module,
87  : module(module), mlirBuilder(module.getContext()), options(options) {}
88 
89 LogicalResult Serializer::serialize() {
90  LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
91 
92  if (failed(module.verifyInvariants()))
93  return failure();
94 
95  // TODO: handle the other sections
96  processCapability();
97  processExtension();
98  processMemoryModel();
99  processDebugInfo();
100 
101  // Iterate over the module body to serialize it. Assumptions are that there is
102  // only one basic block in the moduleOp
103  for (auto &op : *module.getBody()) {
104  if (failed(processOperation(&op))) {
105  return failure();
106  }
107  }
108 
109  LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
110  return success();
111 }
112 
114  auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
115  extensions.size() + extendedSets.size() +
116  memoryModel.size() + entryPoints.size() +
117  executionModes.size() + decorations.size() +
118  typesGlobalValues.size() + functions.size();
119 
120  binary.clear();
121  binary.reserve(moduleSize);
122 
123  spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
124  nextID);
125  binary.append(capabilities.begin(), capabilities.end());
126  binary.append(extensions.begin(), extensions.end());
127  binary.append(extendedSets.begin(), extendedSets.end());
128  binary.append(memoryModel.begin(), memoryModel.end());
129  binary.append(entryPoints.begin(), entryPoints.end());
130  binary.append(executionModes.begin(), executionModes.end());
131  binary.append(debug.begin(), debug.end());
132  binary.append(names.begin(), names.end());
133  binary.append(decorations.begin(), decorations.end());
134  binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
135  binary.append(functions.begin(), functions.end());
136 }
137 
138 #ifndef NDEBUG
139 void Serializer::printValueIDMap(raw_ostream &os) {
140  os << "\n= Value <id> Map =\n\n";
141  for (auto valueIDPair : valueIDMap) {
142  Value val = valueIDPair.first;
143  os << " " << val << " "
144  << "id = " << valueIDPair.second << ' ';
145  if (auto *op = val.getDefiningOp()) {
146  os << "from op '" << op->getName() << "'";
147  } else if (auto arg = dyn_cast<BlockArgument>(val)) {
148  Block *block = arg.getOwner();
149  os << "from argument of block " << block << ' ';
150  os << " in op '" << block->getParentOp()->getName() << "'";
151  }
152  os << '\n';
153  }
154 }
155 #endif
156 
157 //===----------------------------------------------------------------------===//
158 // Module structure
159 //===----------------------------------------------------------------------===//
160 
161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162  auto funcID = funcIDMap.lookup(fnName);
163  if (!funcID) {
164  funcID = getNextID();
165  funcIDMap[fnName] = funcID;
166  }
167  return funcID;
168 }
169 
170 void Serializer::processCapability() {
171  for (auto cap : module.getVceTriple()->getCapabilities())
172  encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
173  {static_cast<uint32_t>(cap)});
174 }
175 
176 void Serializer::processDebugInfo() {
177  if (!options.emitDebugInfo)
178  return;
179  auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180  auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
181  fileID = getNextID();
182  SmallVector<uint32_t, 16> operands;
183  operands.push_back(fileID);
184  spirv::encodeStringLiteralInto(operands, fileName);
185  encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
186  // TODO: Encode more debug instructions.
187 }
188 
189 void Serializer::processExtension() {
191  for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
192  extName.clear();
193  spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
194  encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
195  }
196 }
197 
198 void Serializer::processMemoryModel() {
199  StringAttr memoryModelName = module.getMemoryModelAttrName();
200  auto mm = static_cast<uint32_t>(
201  module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
202  .getValue());
203 
204  StringAttr addressingModelName = module.getAddressingModelAttrName();
205  auto am = static_cast<uint32_t>(
206  module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
207  .getValue());
208 
209  encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
210 }
211 
212 static std::string getDecorationName(StringRef attrName) {
213  // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
214  // expected FPFastMathMode.
215  if (attrName == "fp_fast_math_mode")
216  return "FPFastMathMode";
217  // similar here
218  if (attrName == "fp_rounding_mode")
219  return "FPRoundingMode";
220  // convertToCamelFromSnakeCase will not capitalize "INTEL".
221  if (attrName == "cache_control_load_intel")
222  return "CacheControlLoadINTEL";
223  if (attrName == "cache_control_store_intel")
224  return "CacheControlStoreINTEL";
225 
226  return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
227 }
228 
229 template <typename AttrTy, typename EmitF>
230 LogicalResult processDecorationList(Location loc, Decoration decoration,
231  Attribute attrList, StringRef attrName,
232  EmitF emitter) {
233  auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
234  if (!arrayAttr) {
235  return emitError(loc, "expecting array attribute of ")
236  << attrName << " for " << stringifyDecoration(decoration);
237  }
238  if (arrayAttr.empty()) {
239  return emitError(loc, "expecting non-empty array attribute of ")
240  << attrName << " for " << stringifyDecoration(decoration);
241  }
242  for (Attribute attr : arrayAttr.getValue()) {
243  auto cacheControlAttr = dyn_cast<AttrTy>(attr);
244  if (!cacheControlAttr) {
245  return emitError(loc, "expecting array attribute of ")
246  << attrName << " for " << stringifyDecoration(decoration);
247  }
248  // This named attribute encodes several decorations. Emit one per
249  // element in the array.
250  if (failed(emitter(cacheControlAttr)))
251  return failure();
252  }
253  return success();
254 }
255 
256 LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
257  Decoration decoration,
258  Attribute attr) {
260  switch (decoration) {
261  case spirv::Decoration::LinkageAttributes: {
262  // Get the value of the Linkage Attributes
263  // e.g., LinkageAttributes=["linkageName", linkageType].
264  auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
265  auto linkageName = linkageAttr.getLinkageName();
266  auto linkageType = linkageAttr.getLinkageType().getValue();
267  // Encode the Linkage Name (string literal to uint32_t).
268  spirv::encodeStringLiteralInto(args, linkageName);
269  // Encode LinkageType & Add the Linkagetype to the args.
270  args.push_back(static_cast<uint32_t>(linkageType));
271  break;
272  }
273  case spirv::Decoration::FPFastMathMode:
274  if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
275  args.push_back(static_cast<uint32_t>(intAttr.getValue()));
276  break;
277  }
278  return emitError(loc, "expected FPFastMathModeAttr attribute for ")
279  << stringifyDecoration(decoration);
280  case spirv::Decoration::FPRoundingMode:
281  if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
282  args.push_back(static_cast<uint32_t>(intAttr.getValue()));
283  break;
284  }
285  return emitError(loc, "expected FPRoundingModeAttr attribute for ")
286  << stringifyDecoration(decoration);
287  case spirv::Decoration::Binding:
288  case spirv::Decoration::DescriptorSet:
289  case spirv::Decoration::Location:
290  if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
291  args.push_back(intAttr.getValue().getZExtValue());
292  break;
293  }
294  return emitError(loc, "expected integer attribute for ")
295  << stringifyDecoration(decoration);
296  case spirv::Decoration::BuiltIn:
297  if (auto strAttr = dyn_cast<StringAttr>(attr)) {
298  auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
299  if (enumVal) {
300  args.push_back(static_cast<uint32_t>(*enumVal));
301  break;
302  }
303  return emitError(loc, "invalid ")
304  << stringifyDecoration(decoration) << " decoration attribute "
305  << strAttr.getValue();
306  }
307  return emitError(loc, "expected string attribute for ")
308  << stringifyDecoration(decoration);
309  case spirv::Decoration::Aliased:
310  case spirv::Decoration::AliasedPointer:
311  case spirv::Decoration::Flat:
312  case spirv::Decoration::NonReadable:
313  case spirv::Decoration::NonWritable:
314  case spirv::Decoration::NoPerspective:
315  case spirv::Decoration::NoSignedWrap:
316  case spirv::Decoration::NoUnsignedWrap:
317  case spirv::Decoration::RelaxedPrecision:
318  case spirv::Decoration::Restrict:
319  case spirv::Decoration::RestrictPointer:
320  case spirv::Decoration::NoContraction:
321  case spirv::Decoration::Constant:
322  // For unit attributes and decoration attributes, the args list
323  // has no values so we do nothing.
324  if (isa<UnitAttr, DecorationAttr>(attr))
325  break;
326  return emitError(loc,
327  "expected unit attribute or decoration attribute for ")
328  << stringifyDecoration(decoration);
329  case spirv::Decoration::CacheControlLoadINTEL:
330  return processDecorationList<CacheControlLoadINTELAttr>(
331  loc, decoration, attr, "CacheControlLoadINTEL",
332  [&](CacheControlLoadINTELAttr attr) {
333  unsigned cacheLevel = attr.getCacheLevel();
334  LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
335  return emitDecoration(
336  resultID, decoration,
337  {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
338  });
339  case spirv::Decoration::CacheControlStoreINTEL:
340  return processDecorationList<CacheControlStoreINTELAttr>(
341  loc, decoration, attr, "CacheControlStoreINTEL",
342  [&](CacheControlStoreINTELAttr attr) {
343  unsigned cacheLevel = attr.getCacheLevel();
344  StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
345  return emitDecoration(
346  resultID, decoration,
347  {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
348  });
349  default:
350  return emitError(loc, "unhandled decoration ")
351  << stringifyDecoration(decoration);
352  }
353  return emitDecoration(resultID, decoration, args);
354 }
355 
356 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
357  NamedAttribute attr) {
358  StringRef attrName = attr.getName().strref();
359  std::string decorationName = getDecorationName(attrName);
360  std::optional<Decoration> decoration =
361  spirv::symbolizeDecoration(decorationName);
362  if (!decoration) {
363  return emitError(
364  loc, "non-argument attributes expected to have snake-case-ified "
365  "decoration name, unhandled attribute with name : ")
366  << attrName;
367  }
368  return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
369 }
370 
371 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
372  assert(!name.empty() && "unexpected empty string for OpName");
373  if (!options.emitSymbolName)
374  return success();
375 
376  SmallVector<uint32_t, 4> nameOperands;
377  nameOperands.push_back(resultID);
378  spirv::encodeStringLiteralInto(nameOperands, name);
379  encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
380  return success();
381 }
382 
383 template <>
384 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
385  Location loc, spirv::ArrayType type, uint32_t resultID) {
386  if (unsigned stride = type.getArrayStride()) {
387  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
388  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
389  }
390  return success();
391 }
392 
393 template <>
394 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
395  Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
396  if (unsigned stride = type.getArrayStride()) {
397  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
398  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
399  }
400  return success();
401 }
402 
403 LogicalResult Serializer::processMemberDecoration(
404  uint32_t structID,
405  const spirv::StructType::MemberDecorationInfo &memberDecoration) {
407  {structID, memberDecoration.memberIndex,
408  static_cast<uint32_t>(memberDecoration.decoration)});
409  if (memberDecoration.hasValue) {
410  args.push_back(memberDecoration.decorationValue);
411  }
412  encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
413  return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // Type
418 //===----------------------------------------------------------------------===//
419 
420 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
421 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
422 // PushConstant Storage Classes must be explicitly laid out."
423 bool Serializer::isInterfaceStructPtrType(Type type) const {
424  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
425  switch (ptrType.getStorageClass()) {
426  case spirv::StorageClass::PhysicalStorageBuffer:
427  case spirv::StorageClass::PushConstant:
428  case spirv::StorageClass::StorageBuffer:
429  case spirv::StorageClass::Uniform:
430  return isa<spirv::StructType>(ptrType.getPointeeType());
431  default:
432  break;
433  }
434  }
435  return false;
436 }
437 
438 LogicalResult Serializer::processType(Location loc, Type type,
439  uint32_t &typeID) {
440  // Maintains a set of names for nested identified struct types. This is used
441  // to properly serialize recursive references.
442  SetVector<StringRef> serializationCtx;
443  return processTypeImpl(loc, type, typeID, serializationCtx);
444 }
445 
446 LogicalResult
447 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
448  SetVector<StringRef> &serializationCtx) {
449  typeID = getTypeID(type);
450  if (typeID)
451  return success();
452 
453  typeID = getNextID();
454  SmallVector<uint32_t, 4> operands;
455 
456  operands.push_back(typeID);
457  auto typeEnum = spirv::Opcode::OpTypeVoid;
458  bool deferSerialization = false;
459 
460  if ((isa<FunctionType>(type) &&
461  succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
462  operands))) ||
463  succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
464  deferSerialization, serializationCtx))) {
465  if (deferSerialization)
466  return success();
467 
468  typeIDMap[type] = typeID;
469 
470  encodeInstructionInto(typesGlobalValues, typeEnum, operands);
471 
472  if (recursiveStructInfos.count(type) != 0) {
473  // This recursive struct type is emitted already, now the OpTypePointer
474  // instructions referring to recursive references are emitted as well.
475  for (auto &ptrInfo : recursiveStructInfos[type]) {
476  // TODO: This might not work if more than 1 recursive reference is
477  // present in the struct.
478  SmallVector<uint32_t, 4> ptrOperands;
479  ptrOperands.push_back(ptrInfo.pointerTypeID);
480  ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
481  ptrOperands.push_back(typeIDMap[type]);
482 
483  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
484  ptrOperands);
485  }
486 
487  recursiveStructInfos[type].clear();
488  }
489 
490  return success();
491  }
492 
493  return failure();
494 }
495 
496 LogicalResult Serializer::prepareBasicType(
497  Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
498  SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
499  SetVector<StringRef> &serializationCtx) {
500  deferSerialization = false;
501 
502  if (isVoidType(type)) {
503  typeEnum = spirv::Opcode::OpTypeVoid;
504  return success();
505  }
506 
507  if (auto intType = dyn_cast<IntegerType>(type)) {
508  if (intType.getWidth() == 1) {
509  typeEnum = spirv::Opcode::OpTypeBool;
510  return success();
511  }
512 
513  typeEnum = spirv::Opcode::OpTypeInt;
514  operands.push_back(intType.getWidth());
515  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
516  // to preserve or validate.
517  // 0 indicates unsigned, or no signedness semantics
518  // 1 indicates signed semantics."
519  operands.push_back(intType.isSigned() ? 1 : 0);
520  return success();
521  }
522 
523  if (auto floatType = dyn_cast<FloatType>(type)) {
524  typeEnum = spirv::Opcode::OpTypeFloat;
525  operands.push_back(floatType.getWidth());
526  if (floatType.isBF16()) {
527  operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
528  }
529  return success();
530  }
531 
532  if (auto vectorType = dyn_cast<VectorType>(type)) {
533  uint32_t elementTypeID = 0;
534  if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
535  serializationCtx))) {
536  return failure();
537  }
538  typeEnum = spirv::Opcode::OpTypeVector;
539  operands.push_back(elementTypeID);
540  operands.push_back(vectorType.getNumElements());
541  return success();
542  }
543 
544  if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
545  typeEnum = spirv::Opcode::OpTypeImage;
546  uint32_t sampledTypeID = 0;
547  if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
548  return failure();
549 
550  llvm::append_values(operands, sampledTypeID,
551  static_cast<uint32_t>(imageType.getDim()),
552  static_cast<uint32_t>(imageType.getDepthInfo()),
553  static_cast<uint32_t>(imageType.getArrayedInfo()),
554  static_cast<uint32_t>(imageType.getSamplingInfo()),
555  static_cast<uint32_t>(imageType.getSamplerUseInfo()),
556  static_cast<uint32_t>(imageType.getImageFormat()));
557  return success();
558  }
559 
560  if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
561  typeEnum = spirv::Opcode::OpTypeArray;
562  uint32_t elementTypeID = 0;
563  if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
564  serializationCtx))) {
565  return failure();
566  }
567  operands.push_back(elementTypeID);
568  if (auto elementCountID = prepareConstantInt(
569  loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
570  operands.push_back(elementCountID);
571  }
572  return processTypeDecoration(loc, arrayType, resultID);
573  }
574 
575  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
576  uint32_t pointeeTypeID = 0;
577  spirv::StructType pointeeStruct =
578  dyn_cast<spirv::StructType>(ptrType.getPointeeType());
579 
580  if (pointeeStruct && pointeeStruct.isIdentified() &&
581  serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
582  // A recursive reference to an enclosing struct is found.
583  //
584  // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
585  // class as operands.
586  SmallVector<uint32_t, 2> forwardPtrOperands;
587  forwardPtrOperands.push_back(resultID);
588  forwardPtrOperands.push_back(
589  static_cast<uint32_t>(ptrType.getStorageClass()));
590 
591  encodeInstructionInto(typesGlobalValues,
592  spirv::Opcode::OpTypeForwardPointer,
593  forwardPtrOperands);
594 
595  // 2. Find the pointee (enclosing) struct.
596  auto structType = spirv::StructType::getIdentified(
597  module.getContext(), pointeeStruct.getIdentifier());
598 
599  if (!structType)
600  return failure();
601 
602  // 3. Mark the OpTypePointer that is supposed to be emitted by this call
603  // as deferred.
604  deferSerialization = true;
605 
606  // 4. Record the info needed to emit the deferred OpTypePointer
607  // instruction when the enclosing struct is completely serialized.
608  recursiveStructInfos[structType].push_back(
609  {resultID, ptrType.getStorageClass()});
610  } else {
611  if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
612  serializationCtx)))
613  return failure();
614  }
615 
616  typeEnum = spirv::Opcode::OpTypePointer;
617  operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
618  operands.push_back(pointeeTypeID);
619 
620  if (isInterfaceStructPtrType(ptrType)) {
621  if (failed(emitDecoration(getTypeID(pointeeStruct),
622  spirv::Decoration::Block)))
623  return emitError(loc, "cannot decorate ")
624  << pointeeStruct << " with Block decoration";
625  }
626 
627  return success();
628  }
629 
630  if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
631  uint32_t elementTypeID = 0;
632  if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
633  elementTypeID, serializationCtx))) {
634  return failure();
635  }
636  typeEnum = spirv::Opcode::OpTypeRuntimeArray;
637  operands.push_back(elementTypeID);
638  return processTypeDecoration(loc, runtimeArrayType, resultID);
639  }
640 
641  if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
642  typeEnum = spirv::Opcode::OpTypeSampledImage;
643  uint32_t imageTypeID = 0;
644  if (failed(
645  processType(loc, sampledImageType.getImageType(), imageTypeID))) {
646  return failure();
647  }
648  operands.push_back(imageTypeID);
649  return success();
650  }
651 
652  if (auto structType = dyn_cast<spirv::StructType>(type)) {
653  if (structType.isIdentified()) {
654  if (failed(processName(resultID, structType.getIdentifier())))
655  return failure();
656  serializationCtx.insert(structType.getIdentifier());
657  }
658 
659  bool hasOffset = structType.hasOffset();
660  for (auto elementIndex :
661  llvm::seq<uint32_t>(0, structType.getNumElements())) {
662  uint32_t elementTypeID = 0;
663  if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
664  elementTypeID, serializationCtx))) {
665  return failure();
666  }
667  operands.push_back(elementTypeID);
668  if (hasOffset) {
669  // Decorate each struct member with an offset
671  elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
672  static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
673  if (failed(processMemberDecoration(resultID, offsetDecoration))) {
674  return emitError(loc, "cannot decorate ")
675  << elementIndex << "-th member of " << structType
676  << " with its offset";
677  }
678  }
679  }
681  structType.getMemberDecorations(memberDecorations);
682 
683  for (auto &memberDecoration : memberDecorations) {
684  if (failed(processMemberDecoration(resultID, memberDecoration))) {
685  return emitError(loc, "cannot decorate ")
686  << static_cast<uint32_t>(memberDecoration.memberIndex)
687  << "-th member of " << structType << " with "
688  << stringifyDecoration(memberDecoration.decoration);
689  }
690  }
691 
692  typeEnum = spirv::Opcode::OpTypeStruct;
693 
694  if (structType.isIdentified())
695  serializationCtx.remove(structType.getIdentifier());
696 
697  return success();
698  }
699 
700  if (auto cooperativeMatrixType =
701  dyn_cast<spirv::CooperativeMatrixType>(type)) {
702  uint32_t elementTypeID = 0;
703  if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
704  elementTypeID, serializationCtx))) {
705  return failure();
706  }
707  typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
708  auto getConstantOp = [&](uint32_t id) {
709  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
710  return prepareConstantInt(loc, attr);
711  };
712  llvm::append_values(
713  operands, elementTypeID,
714  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
715  getConstantOp(cooperativeMatrixType.getRows()),
716  getConstantOp(cooperativeMatrixType.getColumns()),
717  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
718  return success();
719  }
720 
721  if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
722  uint32_t elementTypeID = 0;
723  if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
724  serializationCtx))) {
725  return failure();
726  }
727  typeEnum = spirv::Opcode::OpTypeMatrix;
728  llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
729  return success();
730  }
731 
732  // TODO: Handle other types.
733  return emitError(loc, "unhandled type in serialization: ") << type;
734 }
735 
736 LogicalResult
737 Serializer::prepareFunctionType(Location loc, FunctionType type,
738  spirv::Opcode &typeEnum,
739  SmallVectorImpl<uint32_t> &operands) {
740  typeEnum = spirv::Opcode::OpTypeFunction;
741  assert(type.getNumResults() <= 1 &&
742  "serialization supports only a single return value");
743  uint32_t resultID = 0;
744  if (failed(processType(
745  loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
746  resultID))) {
747  return failure();
748  }
749  operands.push_back(resultID);
750  for (auto &res : type.getInputs()) {
751  uint32_t argTypeID = 0;
752  if (failed(processType(loc, res, argTypeID))) {
753  return failure();
754  }
755  operands.push_back(argTypeID);
756  }
757  return success();
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // Constant
762 //===----------------------------------------------------------------------===//
763 
764 uint32_t Serializer::prepareConstant(Location loc, Type constType,
765  Attribute valueAttr) {
766  if (auto id = prepareConstantScalar(loc, valueAttr)) {
767  return id;
768  }
769 
770  // This is a composite literal. We need to handle each component separately
771  // and then emit an OpConstantComposite for the whole.
772 
773  if (auto id = getConstantID(valueAttr)) {
774  return id;
775  }
776 
777  uint32_t typeID = 0;
778  if (failed(processType(loc, constType, typeID))) {
779  return 0;
780  }
781 
782  uint32_t resultID = 0;
783  if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
784  int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
785  SmallVector<uint64_t, 4> index(rank);
786  resultID = prepareDenseElementsConstant(loc, constType, attr,
787  /*dim=*/0, index);
788  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
789  resultID = prepareArrayConstant(loc, constType, arrayAttr);
790  }
791 
792  if (resultID == 0) {
793  emitError(loc, "cannot serialize attribute: ") << valueAttr;
794  return 0;
795  }
796 
797  constIDMap[valueAttr] = resultID;
798  return resultID;
799 }
800 
801 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
802  ArrayAttr attr) {
803  uint32_t typeID = 0;
804  if (failed(processType(loc, constType, typeID))) {
805  return 0;
806  }
807 
808  uint32_t resultID = getNextID();
809  SmallVector<uint32_t, 4> operands = {typeID, resultID};
810  operands.reserve(attr.size() + 2);
811  auto elementType = cast<spirv::ArrayType>(constType).getElementType();
812  for (Attribute elementAttr : attr) {
813  if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
814  operands.push_back(elementID);
815  } else {
816  return 0;
817  }
818  }
819  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
820  encodeInstructionInto(typesGlobalValues, opcode, operands);
821 
822  return resultID;
823 }
824 
825 // TODO: Turn the below function into iterative function, instead of
826 // recursive function.
827 uint32_t
828 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
829  DenseElementsAttr valueAttr, int dim,
831  auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
832  assert(dim <= shapedType.getRank());
833  if (shapedType.getRank() == dim) {
834  if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
835  return attr.getType().getElementType().isInteger(1)
836  ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
837  : prepareConstantInt(loc,
838  attr.getValues<IntegerAttr>()[index]);
839  }
840  if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
841  return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
842  }
843  return 0;
844  }
845 
846  uint32_t typeID = 0;
847  if (failed(processType(loc, constType, typeID))) {
848  return 0;
849  }
850 
851  int64_t numberOfConstituents = shapedType.getDimSize(dim);
852  uint32_t resultID = getNextID();
853  SmallVector<uint32_t, 4> operands = {typeID, resultID};
854  auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
855 
856  // "If the Result Type is a cooperative matrix type, then there must be only
857  // one Constituent, with scalar type matching the cooperative matrix Component
858  // Type, and all components of the matrix are initialized to that value."
859  // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
860  if (isa<spirv::CooperativeMatrixType>(constType)) {
861  if (!valueAttr.isSplat()) {
862  emitError(
863  loc,
864  "cannot serialize a non-splat value for a cooperative matrix type");
865  return 0;
866  }
867  // numberOfConstituents is 1, so we only need one more elements in the
868  // SmallVector, so the total is 3 (1 + 2).
869  operands.reserve(3);
870  // We set dim directly to `shapedType.getRank()` so the recursive call
871  // directly returns the scalar type.
872  if (auto elementID = prepareDenseElementsConstant(
873  loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
874  operands.push_back(elementID);
875  } else {
876  return 0;
877  }
878  } else {
879  operands.reserve(numberOfConstituents + 2);
880  for (int i = 0; i < numberOfConstituents; ++i) {
881  index[dim] = i;
882  if (auto elementID = prepareDenseElementsConstant(
883  loc, elementType, valueAttr, dim + 1, index)) {
884  operands.push_back(elementID);
885  } else {
886  return 0;
887  }
888  }
889  }
890  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
891  encodeInstructionInto(typesGlobalValues, opcode, operands);
892 
893  return resultID;
894 }
895 
896 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
897  bool isSpec) {
898  if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
899  return prepareConstantFp(loc, floatAttr, isSpec);
900  }
901  if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
902  return prepareConstantBool(loc, boolAttr, isSpec);
903  }
904  if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
905  return prepareConstantInt(loc, intAttr, isSpec);
906  }
907 
908  return 0;
909 }
910 
911 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
912  bool isSpec) {
913  if (!isSpec) {
914  // We can de-duplicate normal constants, but not specialization constants.
915  if (auto id = getConstantID(boolAttr)) {
916  return id;
917  }
918  }
919 
920  // Process the type for this bool literal
921  uint32_t typeID = 0;
922  if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
923  return 0;
924  }
925 
926  auto resultID = getNextID();
927  auto opcode = boolAttr.getValue()
928  ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
929  : spirv::Opcode::OpConstantTrue)
930  : (isSpec ? spirv::Opcode::OpSpecConstantFalse
931  : spirv::Opcode::OpConstantFalse);
932  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
933 
934  if (!isSpec) {
935  constIDMap[boolAttr] = resultID;
936  }
937  return resultID;
938 }
939 
940 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
941  bool isSpec) {
942  if (!isSpec) {
943  // We can de-duplicate normal constants, but not specialization constants.
944  if (auto id = getConstantID(intAttr)) {
945  return id;
946  }
947  }
948 
949  // Process the type for this integer literal
950  uint32_t typeID = 0;
951  if (failed(processType(loc, intAttr.getType(), typeID))) {
952  return 0;
953  }
954 
955  auto resultID = getNextID();
956  APInt value = intAttr.getValue();
957  unsigned bitwidth = value.getBitWidth();
958  bool isSigned = intAttr.getType().isSignedInteger();
959  auto opcode =
960  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
961 
962  switch (bitwidth) {
963  // According to SPIR-V spec, "When the type's bit width is less than
964  // 32-bits, the literal's value appears in the low-order bits of the word,
965  // and the high-order bits must be 0 for a floating-point type, or 0 for an
966  // integer type with Signedness of 0, or sign extended when Signedness
967  // is 1."
968  case 32:
969  case 16:
970  case 8: {
971  uint32_t word = 0;
972  if (isSigned) {
973  word = static_cast<int32_t>(value.getSExtValue());
974  } else {
975  word = static_cast<uint32_t>(value.getZExtValue());
976  }
977  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
978  } break;
979  // According to SPIR-V spec: "When the type's bit width is larger than one
980  // word, the literal’s low-order words appear first."
981  case 64: {
982  struct DoubleWord {
983  uint32_t word1;
984  uint32_t word2;
985  } words;
986  if (isSigned) {
987  words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
988  } else {
989  words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
990  }
991  encodeInstructionInto(typesGlobalValues, opcode,
992  {typeID, resultID, words.word1, words.word2});
993  } break;
994  default: {
995  std::string valueStr;
996  llvm::raw_string_ostream rss(valueStr);
997  value.print(rss, /*isSigned=*/false);
998 
999  emitError(loc, "cannot serialize ")
1000  << bitwidth << "-bit integer literal: " << valueStr;
1001  return 0;
1002  }
1003  }
1004 
1005  if (!isSpec) {
1006  constIDMap[intAttr] = resultID;
1007  }
1008  return resultID;
1009 }
1010 
1011 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1012  bool isSpec) {
1013  if (!isSpec) {
1014  // We can de-duplicate normal constants, but not specialization constants.
1015  if (auto id = getConstantID(floatAttr)) {
1016  return id;
1017  }
1018  }
1019 
1020  // Process the type for this float literal
1021  uint32_t typeID = 0;
1022  if (failed(processType(loc, floatAttr.getType(), typeID))) {
1023  return 0;
1024  }
1025 
1026  auto resultID = getNextID();
1027  APFloat value = floatAttr.getValue();
1028  const llvm::fltSemantics *semantics = &value.getSemantics();
1029 
1030  auto opcode =
1031  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1032 
1033  if (semantics == &APFloat::IEEEsingle()) {
1034  uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1035  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1036  } else if (semantics == &APFloat::IEEEdouble()) {
1037  struct DoubleWord {
1038  uint32_t word1;
1039  uint32_t word2;
1040  } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1041  encodeInstructionInto(typesGlobalValues, opcode,
1042  {typeID, resultID, words.word1, words.word2});
1043  } else if (semantics == &APFloat::IEEEhalf() ||
1044  semantics == &APFloat::BFloat()) {
1045  uint32_t word =
1046  static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1047  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1048  } else {
1049  std::string valueStr;
1050  llvm::raw_string_ostream rss(valueStr);
1051  value.print(rss);
1052 
1053  emitError(loc, "cannot serialize ")
1054  << floatAttr.getType() << "-typed float literal: " << valueStr;
1055  return 0;
1056  }
1057 
1058  if (!isSpec) {
1059  constIDMap[floatAttr] = resultID;
1060  }
1061  return resultID;
1062 }
1063 
1064 //===----------------------------------------------------------------------===//
1065 // Control flow
1066 //===----------------------------------------------------------------------===//
1067 
1068 uint32_t Serializer::getOrCreateBlockID(Block *block) {
1069  if (uint32_t id = getBlockID(block))
1070  return id;
1071  return blockIDMap[block] = getNextID();
1072 }
1073 
1074 #ifndef NDEBUG
1075 void Serializer::printBlock(Block *block, raw_ostream &os) {
1076  os << "block " << block << " (id = ";
1077  if (uint32_t id = getBlockID(block))
1078  os << id;
1079  else
1080  os << "unknown";
1081  os << ")\n";
1082 }
1083 #endif
1084 
1085 LogicalResult
1086 Serializer::processBlock(Block *block, bool omitLabel,
1087  function_ref<LogicalResult()> emitMerge) {
1088  LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1089  LLVM_DEBUG(block->print(llvm::dbgs()));
1090  LLVM_DEBUG(llvm::dbgs() << '\n');
1091  if (!omitLabel) {
1092  uint32_t blockID = getOrCreateBlockID(block);
1093  LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1094 
1095  // Emit OpLabel for this block.
1096  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1097  }
1098 
1099  // Emit OpPhi instructions for block arguments, if any.
1100  if (failed(emitPhiForBlockArguments(block)))
1101  return failure();
1102 
1103  // If we need to emit merge instructions, it must happen in this block. Check
1104  // whether we have other structured control flow ops, which will be expanded
1105  // into multiple basic blocks. If that's the case, we need to emit the merge
1106  // right now and then create new blocks for further serialization of the ops
1107  // in this block.
1108  if (emitMerge &&
1109  llvm::any_of(block->getOperations(),
1110  llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1111  if (failed(emitMerge()))
1112  return failure();
1113  emitMerge = nullptr;
1114 
1115  // Start a new block for further serialization.
1116  uint32_t blockID = getNextID();
1117  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1118  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1119  }
1120 
1121  // Process each op in this block except the terminator.
1122  for (Operation &op : llvm::drop_end(*block)) {
1123  if (failed(processOperation(&op)))
1124  return failure();
1125  }
1126 
1127  // Process the terminator.
1128  if (emitMerge)
1129  if (failed(emitMerge()))
1130  return failure();
1131  if (failed(processOperation(&block->back())))
1132  return failure();
1133 
1134  return success();
1135 }
1136 
1137 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1138  // Nothing to do if this block has no arguments or it's the entry block, which
1139  // always has the same arguments as the function signature.
1140  if (block->args_empty() || block->isEntryBlock())
1141  return success();
1142 
1143  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1144 
1145  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1146  // A SPIR-V OpPhi instruction is of the syntax:
1147  // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1148  // So we need to collect all predecessor blocks and the arguments they send
1149  // to this block.
1151  for (Block *mlirPredecessor : block->getPredecessors()) {
1152  auto *terminator = mlirPredecessor->getTerminator();
1153  LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1154  LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1155  LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1156  // The predecessor here is the immediate one according to MLIR's IR
1157  // structure. It does not directly map to the incoming parent block for the
1158  // OpPhi instructions at SPIR-V binary level. This is because structured
1159  // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1160  // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1161  // the branch op jumping to the OpPhi's block then resides in the previous
1162  // structured control flow op's merge block.
1163  Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1164  LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1165  LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1166  if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1167  predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1168  } else if (auto branchCondOp =
1169  dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1170  std::optional<OperandRange> blockOperands;
1171  if (branchCondOp.getTrueTarget() == block) {
1172  blockOperands = branchCondOp.getTrueTargetOperands();
1173  } else {
1174  assert(branchCondOp.getFalseTarget() == block);
1175  blockOperands = branchCondOp.getFalseTargetOperands();
1176  }
1177 
1178  assert(!blockOperands->empty() &&
1179  "expected non-empty block operand range");
1180  predecessors.emplace_back(spirvPredecessor, *blockOperands);
1181  } else {
1182  return terminator->emitError("unimplemented terminator for Phi creation");
1183  }
1184  LLVM_DEBUG({
1185  llvm::dbgs() << " block arguments:\n";
1186  for (Value v : predecessors.back().second)
1187  llvm::dbgs() << " " << v << "\n";
1188  });
1189  }
1190 
1191  // Then create OpPhi instruction for each of the block argument.
1192  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1193  BlockArgument arg = block->getArgument(argIndex);
1194 
1195  // Get the type <id> and result <id> for this OpPhi instruction.
1196  uint32_t phiTypeID = 0;
1197  if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1198  return failure();
1199  uint32_t phiID = getNextID();
1200 
1201  LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1202  << arg << " (id = " << phiID << ")\n");
1203 
1204  // Prepare the (value <id>, parent block <id>) pairs.
1205  SmallVector<uint32_t, 8> phiArgs;
1206  phiArgs.push_back(phiTypeID);
1207  phiArgs.push_back(phiID);
1208 
1209  for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1210  Value value = predecessors[predIndex].second[argIndex];
1211  uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1212  LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1213  << ") value " << value << ' ');
1214  // Each pair is a value <id> ...
1215  uint32_t valueId = getValueID(value);
1216  if (valueId == 0) {
1217  // The op generating this value hasn't been visited yet so we don't have
1218  // an <id> assigned yet. Record this to fix up later.
1219  LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1220  deferredPhiValues[value].push_back(functionBody.size() + 1 +
1221  phiArgs.size());
1222  } else {
1223  LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1224  }
1225  phiArgs.push_back(valueId);
1226  // ... and a parent block <id>.
1227  phiArgs.push_back(predBlockId);
1228  }
1229 
1230  encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1231  valueIDMap[arg] = phiID;
1232  }
1233 
1234  return success();
1235 }
1236 
1237 //===----------------------------------------------------------------------===//
1238 // Operation
1239 //===----------------------------------------------------------------------===//
1240 
1241 LogicalResult Serializer::encodeExtensionInstruction(
1242  Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1243  ArrayRef<uint32_t> operands) {
1244  // Check if the extension has been imported.
1245  auto &setID = extendedInstSetIDMap[extensionSetName];
1246  if (!setID) {
1247  setID = getNextID();
1248  SmallVector<uint32_t, 16> importOperands;
1249  importOperands.push_back(setID);
1250  spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1251  encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1252  importOperands);
1253  }
1254 
1255  // The first two operands are the result type <id> and result <id>. The set
1256  // <id> and the opcode need to be insert after this.
1257  if (operands.size() < 2) {
1258  return op->emitError("extended instructions must have a result encoding");
1259  }
1260  SmallVector<uint32_t, 8> extInstOperands;
1261  extInstOperands.reserve(operands.size() + 2);
1262  extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1263  extInstOperands.push_back(setID);
1264  extInstOperands.push_back(extensionOpcode);
1265  extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1266  encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1267  extInstOperands);
1268  return success();
1269 }
1270 
1271 LogicalResult Serializer::processOperation(Operation *opInst) {
1272  LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1273 
1274  // First dispatch the ops that do not directly mirror an instruction from
1275  // the SPIR-V spec.
1277  .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1278  .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1279  .Case([&](spirv::BranchConditionalOp op) {
1280  return processBranchConditionalOp(op);
1281  })
1282  .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1283  .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1284  .Case([&](spirv::GlobalVariableOp op) {
1285  return processGlobalVariableOp(op);
1286  })
1287  .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1288  .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1289  .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1290  .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1291  .Case([&](spirv::SpecConstantCompositeOp op) {
1292  return processSpecConstantCompositeOp(op);
1293  })
1294  .Case([&](spirv::SpecConstantOperationOp op) {
1295  return processSpecConstantOperationOp(op);
1296  })
1297  .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1298  .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1299 
1300  // Then handle all the ops that directly mirror SPIR-V instructions with
1301  // auto-generated methods.
1302  .Default(
1303  [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1304 }
1305 
1306 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1307  StringRef extInstSet,
1308  uint32_t opcode) {
1309  SmallVector<uint32_t, 4> operands;
1310  Location loc = op->getLoc();
1311 
1312  uint32_t resultID = 0;
1313  if (op->getNumResults() != 0) {
1314  uint32_t resultTypeID = 0;
1315  if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1316  return failure();
1317  operands.push_back(resultTypeID);
1318 
1319  resultID = getNextID();
1320  operands.push_back(resultID);
1321  valueIDMap[op->getResult(0)] = resultID;
1322  };
1323 
1324  for (Value operand : op->getOperands())
1325  operands.push_back(getValueID(operand));
1326 
1327  if (failed(emitDebugLine(functionBody, loc)))
1328  return failure();
1329 
1330  if (extInstSet.empty()) {
1331  encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1332  operands);
1333  } else {
1334  if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1335  return failure();
1336  }
1337 
1338  if (op->getNumResults() != 0) {
1339  for (auto attr : op->getAttrs()) {
1340  if (failed(processDecoration(loc, resultID, attr)))
1341  return failure();
1342  }
1343  }
1344 
1345  return success();
1346 }
1347 
1348 LogicalResult Serializer::emitDecoration(uint32_t target,
1349  spirv::Decoration decoration,
1350  ArrayRef<uint32_t> params) {
1351  uint32_t wordCount = 3 + params.size();
1352  llvm::append_values(
1353  decorations,
1354  spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1355  static_cast<uint32_t>(decoration));
1356  llvm::append_range(decorations, params);
1357  return success();
1358 }
1359 
1360 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1361  Location loc) {
1362  if (!options.emitDebugInfo)
1363  return success();
1364 
1365  if (lastProcessedWasMergeInst) {
1366  lastProcessedWasMergeInst = false;
1367  return success();
1368  }
1369 
1370  auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1371  if (fileLoc)
1372  encodeInstructionInto(binary, spirv::Opcode::OpLine,
1373  {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1374  return success();
1375 }
1376 } // namespace spirv
1377 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
Definition: Serializer.cpp:36
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Definition: Serializer.cpp:47
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Location getLoc() const
Return the location for this argument.
Definition: Value.h:324
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation & back()
Definition: Block.h:152
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:237
OpListType & getOperations()
Definition: Block.h:137
void print(raw_ostream &os)
bool args_empty()
Definition: Block.h:99
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:38
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:198
An attribute that represents a reference to a dense vector or tensor object.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:494
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Definition: Serializer.cpp:139
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
Definition: Serializer.cpp:85
LogicalResult serialize()
Serializes the remembered SPIR-V module.
Definition: Serializer.cpp:89
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
Definition: Serializer.cpp:113
SPIR-V struct type.
Definition: SPIRVTypes.h:293
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
Definition: Serializer.cpp:230
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:78
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Definition: Serializer.cpp:212
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
Definition: Serialization.h:27
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.
Definition: Serialization.h:29