MLIR  18.0.0git
Serializer.cpp
Go to the documentation of this file.
1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
27 #include <cstdint>
28 #include <optional>
29 
30 #define DEBUG_TYPE "spirv-serialization"
31 
32 using namespace mlir;
33 
34 /// Returns the merge block if the given `op` is a structured control flow op.
35 /// Otherwise returns nullptr.
37  if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38  return selectionOp.getMergeBlock();
39  if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40  return loopOp.getMergeBlock();
41  return nullptr;
42 }
43 
44 /// Given a predecessor `block` for a block with arguments, returns the block
45 /// that should be used as the parent block for SPIR-V OpPhi instructions
46 /// corresponding to the block arguments.
47 static Block *getPhiIncomingBlock(Block *block) {
48  // If the predecessor block in question is the entry block for a
49  // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50  if (block->isEntryBlock()) {
51  if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52  // Then the incoming parent block for OpPhi should be the merge block of
53  // the structured control flow op before this loop.
54  Operation *op = loopOp.getOperation();
55  while ((op = op->getPrevNode()) != nullptr)
56  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57  return incomingBlock;
58  // Or the enclosing block itself if no structured control flow ops
59  // exists before this loop.
60  return loopOp->getBlock();
61  }
62  }
63 
64  // Otherwise, we jump from the given predecessor block. Try to see if there is
65  // a structured control flow op inside it.
66  for (Operation &op : llvm::reverse(block->getOperations())) {
67  if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68  return incomingBlock;
69  }
70  return block;
71 }
72 
73 namespace mlir {
74 namespace spirv {
75 
76 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
77 /// the given `binary` vector.
78 void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
79  ArrayRef<uint32_t> operands) {
80  uint32_t wordCount = 1 + operands.size();
81  binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
82  binary.append(operands.begin(), operands.end());
83 }
84 
85 Serializer::Serializer(spirv::ModuleOp module,
87  : module(module), mlirBuilder(module.getContext()), options(options) {}
88 
90  LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
91 
92  if (failed(module.verifyInvariants()))
93  return failure();
94 
95  // TODO: handle the other sections
96  processCapability();
97  processExtension();
98  processMemoryModel();
99  processDebugInfo();
100 
101  // Iterate over the module body to serialize it. Assumptions are that there is
102  // only one basic block in the moduleOp
103  for (auto &op : *module.getBody()) {
104  if (failed(processOperation(&op))) {
105  return failure();
106  }
107  }
108 
109  LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
110  return success();
111 }
112 
114  auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
115  extensions.size() + extendedSets.size() +
116  memoryModel.size() + entryPoints.size() +
117  executionModes.size() + decorations.size() +
118  typesGlobalValues.size() + functions.size();
119 
120  binary.clear();
121  binary.reserve(moduleSize);
122 
123  spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
124  nextID);
125  binary.append(capabilities.begin(), capabilities.end());
126  binary.append(extensions.begin(), extensions.end());
127  binary.append(extendedSets.begin(), extendedSets.end());
128  binary.append(memoryModel.begin(), memoryModel.end());
129  binary.append(entryPoints.begin(), entryPoints.end());
130  binary.append(executionModes.begin(), executionModes.end());
131  binary.append(debug.begin(), debug.end());
132  binary.append(names.begin(), names.end());
133  binary.append(decorations.begin(), decorations.end());
134  binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
135  binary.append(functions.begin(), functions.end());
136 }
137 
138 #ifndef NDEBUG
139 void Serializer::printValueIDMap(raw_ostream &os) {
140  os << "\n= Value <id> Map =\n\n";
141  for (auto valueIDPair : valueIDMap) {
142  Value val = valueIDPair.first;
143  os << " " << val << " "
144  << "id = " << valueIDPair.second << ' ';
145  if (auto *op = val.getDefiningOp()) {
146  os << "from op '" << op->getName() << "'";
147  } else if (auto arg = dyn_cast<BlockArgument>(val)) {
148  Block *block = arg.getOwner();
149  os << "from argument of block " << block << ' ';
150  os << " in op '" << block->getParentOp()->getName() << "'";
151  }
152  os << '\n';
153  }
154 }
155 #endif
156 
157 //===----------------------------------------------------------------------===//
158 // Module structure
159 //===----------------------------------------------------------------------===//
160 
161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162  auto funcID = funcIDMap.lookup(fnName);
163  if (!funcID) {
164  funcID = getNextID();
165  funcIDMap[fnName] = funcID;
166  }
167  return funcID;
168 }
169 
170 void Serializer::processCapability() {
171  for (auto cap : module.getVceTriple()->getCapabilities())
172  encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
173  {static_cast<uint32_t>(cap)});
174 }
175 
176 void Serializer::processDebugInfo() {
177  if (!options.emitDebugInfo)
178  return;
179  auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180  auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
181  fileID = getNextID();
182  SmallVector<uint32_t, 16> operands;
183  operands.push_back(fileID);
184  spirv::encodeStringLiteralInto(operands, fileName);
185  encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
186  // TODO: Encode more debug instructions.
187 }
188 
189 void Serializer::processExtension() {
191  for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
192  extName.clear();
193  spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
194  encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
195  }
196 }
197 
198 void Serializer::processMemoryModel() {
199  auto mm = static_cast<uint32_t>(
200  module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
201  auto am = static_cast<uint32_t>(
202  module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
203  .getValue());
204 
205  encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
206 }
207 
208 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
209  NamedAttribute attr) {
210  auto attrName = attr.getName().strref();
211  auto decorationName =
212  llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
213  auto decoration = spirv::symbolizeDecoration(decorationName);
214  if (!decoration) {
215  return emitError(
216  loc, "non-argument attributes expected to have snake-case-ified "
217  "decoration name, unhandled attribute with name : ")
218  << attrName;
219  }
221  switch (*decoration) {
222  case spirv::Decoration::LinkageAttributes: {
223  // Get the value of the Linkage Attributes
224  // e.g., LinkageAttributes=["linkageName", linkageType].
225  auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.getValue());
226  auto linkageName = linkageAttr.getLinkageName();
227  auto linkageType = linkageAttr.getLinkageType().getValue();
228  // Encode the Linkage Name (string literal to uint32_t).
229  spirv::encodeStringLiteralInto(args, linkageName);
230  // Encode LinkageType & Add the Linkagetype to the args.
231  args.push_back(static_cast<uint32_t>(linkageType));
232  break;
233  }
234  case spirv::Decoration::Binding:
235  case spirv::Decoration::DescriptorSet:
236  case spirv::Decoration::Location:
237  if (auto intAttr = dyn_cast<IntegerAttr>(attr.getValue())) {
238  args.push_back(intAttr.getValue().getZExtValue());
239  break;
240  }
241  return emitError(loc, "expected integer attribute for ") << attrName;
242  case spirv::Decoration::BuiltIn:
243  if (auto strAttr = dyn_cast<StringAttr>(attr.getValue())) {
244  auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
245  if (enumVal) {
246  args.push_back(static_cast<uint32_t>(*enumVal));
247  break;
248  }
249  return emitError(loc, "invalid ")
250  << attrName << " attribute " << strAttr.getValue();
251  }
252  return emitError(loc, "expected string attribute for ") << attrName;
253  case spirv::Decoration::Aliased:
254  case spirv::Decoration::Flat:
255  case spirv::Decoration::NonReadable:
256  case spirv::Decoration::NonWritable:
257  case spirv::Decoration::NoPerspective:
258  case spirv::Decoration::Restrict:
259  case spirv::Decoration::RelaxedPrecision:
260  // For unit attributes, the args list has no values so we do nothing
261  if (auto unitAttr = dyn_cast<UnitAttr>(attr.getValue()))
262  break;
263  return emitError(loc, "expected unit attribute for ") << attrName;
264  default:
265  return emitError(loc, "unhandled decoration ") << decorationName;
266  }
267  return emitDecoration(resultID, *decoration, args);
268 }
269 
270 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
271  assert(!name.empty() && "unexpected empty string for OpName");
272  if (!options.emitSymbolName)
273  return success();
274 
275  SmallVector<uint32_t, 4> nameOperands;
276  nameOperands.push_back(resultID);
277  spirv::encodeStringLiteralInto(nameOperands, name);
278  encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
279  return success();
280 }
281 
282 template <>
283 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
284  Location loc, spirv::ArrayType type, uint32_t resultID) {
285  if (unsigned stride = type.getArrayStride()) {
286  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
287  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
288  }
289  return success();
290 }
291 
292 template <>
293 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
294  Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
295  if (unsigned stride = type.getArrayStride()) {
296  // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
297  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
298  }
299  return success();
300 }
301 
302 LogicalResult Serializer::processMemberDecoration(
303  uint32_t structID,
304  const spirv::StructType::MemberDecorationInfo &memberDecoration) {
306  {structID, memberDecoration.memberIndex,
307  static_cast<uint32_t>(memberDecoration.decoration)});
308  if (memberDecoration.hasValue) {
309  args.push_back(memberDecoration.decorationValue);
310  }
311  encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
312  return success();
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Type
317 //===----------------------------------------------------------------------===//
318 
319 // According to the SPIR-V spec "Validation Rules for Shader Capabilities":
320 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
321 // PushConstant Storage Classes must be explicitly laid out."
322 bool Serializer::isInterfaceStructPtrType(Type type) const {
323  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
324  switch (ptrType.getStorageClass()) {
325  case spirv::StorageClass::PhysicalStorageBuffer:
326  case spirv::StorageClass::PushConstant:
327  case spirv::StorageClass::StorageBuffer:
328  case spirv::StorageClass::Uniform:
329  return isa<spirv::StructType>(ptrType.getPointeeType());
330  default:
331  break;
332  }
333  }
334  return false;
335 }
336 
337 LogicalResult Serializer::processType(Location loc, Type type,
338  uint32_t &typeID) {
339  // Maintains a set of names for nested identified struct types. This is used
340  // to properly serialize recursive references.
341  SetVector<StringRef> serializationCtx;
342  return processTypeImpl(loc, type, typeID, serializationCtx);
343 }
344 
346 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
347  SetVector<StringRef> &serializationCtx) {
348  typeID = getTypeID(type);
349  if (typeID)
350  return success();
351 
352  typeID = getNextID();
353  SmallVector<uint32_t, 4> operands;
354 
355  operands.push_back(typeID);
356  auto typeEnum = spirv::Opcode::OpTypeVoid;
357  bool deferSerialization = false;
358 
359  if ((isa<FunctionType>(type) &&
360  succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
361  operands))) ||
362  succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
363  deferSerialization, serializationCtx))) {
364  if (deferSerialization)
365  return success();
366 
367  typeIDMap[type] = typeID;
368 
369  encodeInstructionInto(typesGlobalValues, typeEnum, operands);
370 
371  if (recursiveStructInfos.count(type) != 0) {
372  // This recursive struct type is emitted already, now the OpTypePointer
373  // instructions referring to recursive references are emitted as well.
374  for (auto &ptrInfo : recursiveStructInfos[type]) {
375  // TODO: This might not work if more than 1 recursive reference is
376  // present in the struct.
377  SmallVector<uint32_t, 4> ptrOperands;
378  ptrOperands.push_back(ptrInfo.pointerTypeID);
379  ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
380  ptrOperands.push_back(typeIDMap[type]);
381 
382  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
383  ptrOperands);
384  }
385 
386  recursiveStructInfos[type].clear();
387  }
388 
389  return success();
390  }
391 
392  return failure();
393 }
394 
395 LogicalResult Serializer::prepareBasicType(
396  Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
397  SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
398  SetVector<StringRef> &serializationCtx) {
399  deferSerialization = false;
400 
401  if (isVoidType(type)) {
402  typeEnum = spirv::Opcode::OpTypeVoid;
403  return success();
404  }
405 
406  if (auto intType = dyn_cast<IntegerType>(type)) {
407  if (intType.getWidth() == 1) {
408  typeEnum = spirv::Opcode::OpTypeBool;
409  return success();
410  }
411 
412  typeEnum = spirv::Opcode::OpTypeInt;
413  operands.push_back(intType.getWidth());
414  // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
415  // to preserve or validate.
416  // 0 indicates unsigned, or no signedness semantics
417  // 1 indicates signed semantics."
418  operands.push_back(intType.isSigned() ? 1 : 0);
419  return success();
420  }
421 
422  if (auto floatType = dyn_cast<FloatType>(type)) {
423  typeEnum = spirv::Opcode::OpTypeFloat;
424  operands.push_back(floatType.getWidth());
425  return success();
426  }
427 
428  if (auto vectorType = dyn_cast<VectorType>(type)) {
429  uint32_t elementTypeID = 0;
430  if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
431  serializationCtx))) {
432  return failure();
433  }
434  typeEnum = spirv::Opcode::OpTypeVector;
435  operands.push_back(elementTypeID);
436  operands.push_back(vectorType.getNumElements());
437  return success();
438  }
439 
440  if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
441  typeEnum = spirv::Opcode::OpTypeImage;
442  uint32_t sampledTypeID = 0;
443  if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
444  return failure();
445 
446  operands.push_back(sampledTypeID);
447  operands.push_back(static_cast<uint32_t>(imageType.getDim()));
448  operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo()));
449  operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo()));
450  operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo()));
451  operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo()));
452  operands.push_back(static_cast<uint32_t>(imageType.getImageFormat()));
453  return success();
454  }
455 
456  if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
457  typeEnum = spirv::Opcode::OpTypeArray;
458  uint32_t elementTypeID = 0;
459  if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
460  serializationCtx))) {
461  return failure();
462  }
463  operands.push_back(elementTypeID);
464  if (auto elementCountID = prepareConstantInt(
465  loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
466  operands.push_back(elementCountID);
467  }
468  return processTypeDecoration(loc, arrayType, resultID);
469  }
470 
471  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
472  uint32_t pointeeTypeID = 0;
473  spirv::StructType pointeeStruct =
474  dyn_cast<spirv::StructType>(ptrType.getPointeeType());
475 
476  if (pointeeStruct && pointeeStruct.isIdentified() &&
477  serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
478  // A recursive reference to an enclosing struct is found.
479  //
480  // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
481  // class as operands.
482  SmallVector<uint32_t, 2> forwardPtrOperands;
483  forwardPtrOperands.push_back(resultID);
484  forwardPtrOperands.push_back(
485  static_cast<uint32_t>(ptrType.getStorageClass()));
486 
487  encodeInstructionInto(typesGlobalValues,
488  spirv::Opcode::OpTypeForwardPointer,
489  forwardPtrOperands);
490 
491  // 2. Find the pointee (enclosing) struct.
492  auto structType = spirv::StructType::getIdentified(
493  module.getContext(), pointeeStruct.getIdentifier());
494 
495  if (!structType)
496  return failure();
497 
498  // 3. Mark the OpTypePointer that is supposed to be emitted by this call
499  // as deferred.
500  deferSerialization = true;
501 
502  // 4. Record the info needed to emit the deferred OpTypePointer
503  // instruction when the enclosing struct is completely serialized.
504  recursiveStructInfos[structType].push_back(
505  {resultID, ptrType.getStorageClass()});
506  } else {
507  if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
508  serializationCtx)))
509  return failure();
510  }
511 
512  typeEnum = spirv::Opcode::OpTypePointer;
513  operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
514  operands.push_back(pointeeTypeID);
515 
516  if (isInterfaceStructPtrType(ptrType)) {
517  if (failed(emitDecoration(getTypeID(pointeeStruct),
518  spirv::Decoration::Block)))
519  return emitError(loc, "cannot decorate ")
520  << pointeeStruct << " with Block decoration";
521  }
522 
523  return success();
524  }
525 
526  if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
527  uint32_t elementTypeID = 0;
528  if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
529  elementTypeID, serializationCtx))) {
530  return failure();
531  }
532  typeEnum = spirv::Opcode::OpTypeRuntimeArray;
533  operands.push_back(elementTypeID);
534  return processTypeDecoration(loc, runtimeArrayType, resultID);
535  }
536 
537  if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
538  typeEnum = spirv::Opcode::OpTypeSampledImage;
539  uint32_t imageTypeID = 0;
540  if (failed(
541  processType(loc, sampledImageType.getImageType(), imageTypeID))) {
542  return failure();
543  }
544  operands.push_back(imageTypeID);
545  return success();
546  }
547 
548  if (auto structType = dyn_cast<spirv::StructType>(type)) {
549  if (structType.isIdentified()) {
550  if (failed(processName(resultID, structType.getIdentifier())))
551  return failure();
552  serializationCtx.insert(structType.getIdentifier());
553  }
554 
555  bool hasOffset = structType.hasOffset();
556  for (auto elementIndex :
557  llvm::seq<uint32_t>(0, structType.getNumElements())) {
558  uint32_t elementTypeID = 0;
559  if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
560  elementTypeID, serializationCtx))) {
561  return failure();
562  }
563  operands.push_back(elementTypeID);
564  if (hasOffset) {
565  // Decorate each struct member with an offset
567  elementIndex, /*hasValue=*/1, spirv::Decoration::Offset,
568  static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
569  if (failed(processMemberDecoration(resultID, offsetDecoration))) {
570  return emitError(loc, "cannot decorate ")
571  << elementIndex << "-th member of " << structType
572  << " with its offset";
573  }
574  }
575  }
577  structType.getMemberDecorations(memberDecorations);
578 
579  for (auto &memberDecoration : memberDecorations) {
580  if (failed(processMemberDecoration(resultID, memberDecoration))) {
581  return emitError(loc, "cannot decorate ")
582  << static_cast<uint32_t>(memberDecoration.memberIndex)
583  << "-th member of " << structType << " with "
584  << stringifyDecoration(memberDecoration.decoration);
585  }
586  }
587 
588  typeEnum = spirv::Opcode::OpTypeStruct;
589 
590  if (structType.isIdentified())
591  serializationCtx.remove(structType.getIdentifier());
592 
593  return success();
594  }
595 
596  if (auto cooperativeMatrixType =
597  dyn_cast<spirv::CooperativeMatrixType>(type)) {
598  uint32_t elementTypeID = 0;
599  if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
600  elementTypeID, serializationCtx))) {
601  return failure();
602  }
603  typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
604  auto getConstantOp = [&](uint32_t id) {
605  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
606  return prepareConstantInt(loc, attr);
607  };
608  operands.push_back(elementTypeID);
609  operands.push_back(
610  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
611  operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
612  operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
613  operands.push_back(
614  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
615  return success();
616  }
617 
618  if (auto cooperativeMatrixType =
619  dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
620  uint32_t elementTypeID = 0;
621  if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
622  elementTypeID, serializationCtx))) {
623  return failure();
624  }
625  typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
626  auto getConstantOp = [&](uint32_t id) {
627  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
628  return prepareConstantInt(loc, attr);
629  };
630  operands.push_back(elementTypeID);
631  operands.push_back(
632  getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
633  operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
634  operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
635  return success();
636  }
637 
638  if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
639  uint32_t elementTypeID = 0;
640  if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
641  elementTypeID, serializationCtx))) {
642  return failure();
643  }
644  typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
645  auto getConstantOp = [&](uint32_t id) {
646  auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
647  return prepareConstantInt(loc, attr);
648  };
649  operands.push_back(elementTypeID);
650  operands.push_back(getConstantOp(jointMatrixType.getRows()));
651  operands.push_back(getConstantOp(jointMatrixType.getColumns()));
652  operands.push_back(getConstantOp(
653  static_cast<uint32_t>(jointMatrixType.getMatrixLayout())));
654  operands.push_back(
655  getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope())));
656  return success();
657  }
658 
659  if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
660  uint32_t elementTypeID = 0;
661  if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
662  serializationCtx))) {
663  return failure();
664  }
665  typeEnum = spirv::Opcode::OpTypeMatrix;
666  operands.push_back(elementTypeID);
667  operands.push_back(matrixType.getNumColumns());
668  return success();
669  }
670 
671  // TODO: Handle other types.
672  return emitError(loc, "unhandled type in serialization: ") << type;
673 }
674 
676 Serializer::prepareFunctionType(Location loc, FunctionType type,
677  spirv::Opcode &typeEnum,
678  SmallVectorImpl<uint32_t> &operands) {
679  typeEnum = spirv::Opcode::OpTypeFunction;
680  assert(type.getNumResults() <= 1 &&
681  "serialization supports only a single return value");
682  uint32_t resultID = 0;
683  if (failed(processType(
684  loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
685  resultID))) {
686  return failure();
687  }
688  operands.push_back(resultID);
689  for (auto &res : type.getInputs()) {
690  uint32_t argTypeID = 0;
691  if (failed(processType(loc, res, argTypeID))) {
692  return failure();
693  }
694  operands.push_back(argTypeID);
695  }
696  return success();
697 }
698 
699 //===----------------------------------------------------------------------===//
700 // Constant
701 //===----------------------------------------------------------------------===//
702 
703 uint32_t Serializer::prepareConstant(Location loc, Type constType,
704  Attribute valueAttr) {
705  if (auto id = prepareConstantScalar(loc, valueAttr)) {
706  return id;
707  }
708 
709  // This is a composite literal. We need to handle each component separately
710  // and then emit an OpConstantComposite for the whole.
711 
712  if (auto id = getConstantID(valueAttr)) {
713  return id;
714  }
715 
716  uint32_t typeID = 0;
717  if (failed(processType(loc, constType, typeID))) {
718  return 0;
719  }
720 
721  uint32_t resultID = 0;
722  if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
723  int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
724  SmallVector<uint64_t, 4> index(rank);
725  resultID = prepareDenseElementsConstant(loc, constType, attr,
726  /*dim=*/0, index);
727  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
728  resultID = prepareArrayConstant(loc, constType, arrayAttr);
729  }
730 
731  if (resultID == 0) {
732  emitError(loc, "cannot serialize attribute: ") << valueAttr;
733  return 0;
734  }
735 
736  constIDMap[valueAttr] = resultID;
737  return resultID;
738 }
739 
740 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
741  ArrayAttr attr) {
742  uint32_t typeID = 0;
743  if (failed(processType(loc, constType, typeID))) {
744  return 0;
745  }
746 
747  uint32_t resultID = getNextID();
748  SmallVector<uint32_t, 4> operands = {typeID, resultID};
749  operands.reserve(attr.size() + 2);
750  auto elementType = cast<spirv::ArrayType>(constType).getElementType();
751  for (Attribute elementAttr : attr) {
752  if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
753  operands.push_back(elementID);
754  } else {
755  return 0;
756  }
757  }
758  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
759  encodeInstructionInto(typesGlobalValues, opcode, operands);
760 
761  return resultID;
762 }
763 
764 // TODO: Turn the below function into iterative function, instead of
765 // recursive function.
766 uint32_t
767 Serializer::prepareDenseElementsConstant(Location loc, Type constType,
768  DenseElementsAttr valueAttr, int dim,
770  auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
771  assert(dim <= shapedType.getRank());
772  if (shapedType.getRank() == dim) {
773  if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
774  return attr.getType().getElementType().isInteger(1)
775  ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
776  : prepareConstantInt(loc,
777  attr.getValues<IntegerAttr>()[index]);
778  }
779  if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
780  return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
781  }
782  return 0;
783  }
784 
785  uint32_t typeID = 0;
786  if (failed(processType(loc, constType, typeID))) {
787  return 0;
788  }
789 
790  uint32_t resultID = getNextID();
791  SmallVector<uint32_t, 4> operands = {typeID, resultID};
792  operands.reserve(shapedType.getDimSize(dim) + 2);
793  auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
794  for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
795  index[dim] = i;
796  if (auto elementID = prepareDenseElementsConstant(
797  loc, elementType, valueAttr, dim + 1, index)) {
798  operands.push_back(elementID);
799  } else {
800  return 0;
801  }
802  }
803  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
804  encodeInstructionInto(typesGlobalValues, opcode, operands);
805 
806  return resultID;
807 }
808 
809 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
810  bool isSpec) {
811  if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
812  return prepareConstantFp(loc, floatAttr, isSpec);
813  }
814  if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
815  return prepareConstantBool(loc, boolAttr, isSpec);
816  }
817  if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
818  return prepareConstantInt(loc, intAttr, isSpec);
819  }
820 
821  return 0;
822 }
823 
824 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
825  bool isSpec) {
826  if (!isSpec) {
827  // We can de-duplicate normal constants, but not specialization constants.
828  if (auto id = getConstantID(boolAttr)) {
829  return id;
830  }
831  }
832 
833  // Process the type for this bool literal
834  uint32_t typeID = 0;
835  if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
836  return 0;
837  }
838 
839  auto resultID = getNextID();
840  auto opcode = boolAttr.getValue()
841  ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
842  : spirv::Opcode::OpConstantTrue)
843  : (isSpec ? spirv::Opcode::OpSpecConstantFalse
844  : spirv::Opcode::OpConstantFalse);
845  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
846 
847  if (!isSpec) {
848  constIDMap[boolAttr] = resultID;
849  }
850  return resultID;
851 }
852 
853 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
854  bool isSpec) {
855  if (!isSpec) {
856  // We can de-duplicate normal constants, but not specialization constants.
857  if (auto id = getConstantID(intAttr)) {
858  return id;
859  }
860  }
861 
862  // Process the type for this integer literal
863  uint32_t typeID = 0;
864  if (failed(processType(loc, intAttr.getType(), typeID))) {
865  return 0;
866  }
867 
868  auto resultID = getNextID();
869  APInt value = intAttr.getValue();
870  unsigned bitwidth = value.getBitWidth();
871  bool isSigned = intAttr.getType().isSignedInteger();
872  auto opcode =
873  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
874 
875  switch (bitwidth) {
876  // According to SPIR-V spec, "When the type's bit width is less than
877  // 32-bits, the literal's value appears in the low-order bits of the word,
878  // and the high-order bits must be 0 for a floating-point type, or 0 for an
879  // integer type with Signedness of 0, or sign extended when Signedness
880  // is 1."
881  case 32:
882  case 16:
883  case 8: {
884  uint32_t word = 0;
885  if (isSigned) {
886  word = static_cast<int32_t>(value.getSExtValue());
887  } else {
888  word = static_cast<uint32_t>(value.getZExtValue());
889  }
890  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
891  } break;
892  // According to SPIR-V spec: "When the type's bit width is larger than one
893  // word, the literal’s low-order words appear first."
894  case 64: {
895  struct DoubleWord {
896  uint32_t word1;
897  uint32_t word2;
898  } words;
899  if (isSigned) {
900  words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
901  } else {
902  words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
903  }
904  encodeInstructionInto(typesGlobalValues, opcode,
905  {typeID, resultID, words.word1, words.word2});
906  } break;
907  default: {
908  std::string valueStr;
909  llvm::raw_string_ostream rss(valueStr);
910  value.print(rss, /*isSigned=*/false);
911 
912  emitError(loc, "cannot serialize ")
913  << bitwidth << "-bit integer literal: " << rss.str();
914  return 0;
915  }
916  }
917 
918  if (!isSpec) {
919  constIDMap[intAttr] = resultID;
920  }
921  return resultID;
922 }
923 
924 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
925  bool isSpec) {
926  if (!isSpec) {
927  // We can de-duplicate normal constants, but not specialization constants.
928  if (auto id = getConstantID(floatAttr)) {
929  return id;
930  }
931  }
932 
933  // Process the type for this float literal
934  uint32_t typeID = 0;
935  if (failed(processType(loc, floatAttr.getType(), typeID))) {
936  return 0;
937  }
938 
939  auto resultID = getNextID();
940  APFloat value = floatAttr.getValue();
941  APInt intValue = value.bitcastToAPInt();
942 
943  auto opcode =
944  isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
945 
946  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
947  uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
948  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
949  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
950  struct DoubleWord {
951  uint32_t word1;
952  uint32_t word2;
953  } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
954  encodeInstructionInto(typesGlobalValues, opcode,
955  {typeID, resultID, words.word1, words.word2});
956  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
957  uint32_t word =
958  static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
959  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
960  } else {
961  std::string valueStr;
962  llvm::raw_string_ostream rss(valueStr);
963  value.print(rss);
964 
965  emitError(loc, "cannot serialize ")
966  << floatAttr.getType() << "-typed float literal: " << rss.str();
967  return 0;
968  }
969 
970  if (!isSpec) {
971  constIDMap[floatAttr] = resultID;
972  }
973  return resultID;
974 }
975 
976 //===----------------------------------------------------------------------===//
977 // Control flow
978 //===----------------------------------------------------------------------===//
979 
980 uint32_t Serializer::getOrCreateBlockID(Block *block) {
981  if (uint32_t id = getBlockID(block))
982  return id;
983  return blockIDMap[block] = getNextID();
984 }
985 
986 #ifndef NDEBUG
987 void Serializer::printBlock(Block *block, raw_ostream &os) {
988  os << "block " << block << " (id = ";
989  if (uint32_t id = getBlockID(block))
990  os << id;
991  else
992  os << "unknown";
993  os << ")\n";
994 }
995 #endif
996 
/// Serializes `block` into functionBody. The steps are, in order:
///   1. an OpLabel for the block (skipped when `omitLabel` is set);
///   2. OpPhi instructions for its arguments (see emitPhiForBlockArguments);
///   3. every non-terminator op, then the terminator, via processOperation.
///
/// If `emitMerge` is provided, it is invoked exactly once for this block.
/// When the block contains a spirv.mlir.loop / spirv.mlir.selection op, the
/// merge is emitted before any op is processed, followed by an OpBranch to a
/// freshly labeled block. Otherwise it is emitted right before the terminator.
/// Returns failure as soon as any step fails.
///
/// NOTE(review): the return-type line (original line 997) is missing from
/// this listing; the success()/failure() returns indicate it is LogicalResult.
998 Serializer::processBlock(Block *block, bool omitLabel,
999  function_ref<LogicalResult()> emitMerge) {
1000  LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1001  LLVM_DEBUG(block->print(llvm::dbgs()));
1002  LLVM_DEBUG(llvm::dbgs() << '\n');
1003  if (!omitLabel) {
1004  uint32_t blockID = getOrCreateBlockID(block);
1005  LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1006 
1007  // Emit OpLabel for this block.
1008  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1009  }
1010 
1011  // Emit OpPhi instructions for block arguments, if any.
1012  if (failed(emitPhiForBlockArguments(block)))
1013  return failure();
1014 
1015  // If we need to emit merge instructions, it must happen in this block. Check
1016  // whether we have other structured control flow ops, which will be expanded
1017  // into multiple basic blocks. If that's the case, we need to emit the merge
1018  // right now and then create new blocks for further serialization of the ops
1019  // in this block.
1020  if (emitMerge && llvm::any_of(block->getOperations(), [](Operation &op) {
1021  return isa<spirv::LoopOp, spirv::SelectionOp>(op);
1022  })) {
1023  if (failed(emitMerge()))
1024  return failure();
// Clear the callback so the terminator path below does not emit it again.
1025  emitMerge = nullptr;
1026 
1027  // Start a new block for further serialization.
// This anonymous block's <id> is not recorded in blockIDMap; it is only
// referenced by the OpBranch emitted right here.
1028  uint32_t blockID = getNextID();
1029  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1030  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1031  }
1032 
1033  // Process each op in this block except the terminator.
// Assumes the block is non-empty (it ends with a terminator); std::prev on
// the end iterator of an empty block would be invalid.
1034  for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) {
1035  if (failed(processOperation(&op)))
1036  return failure();
1037  }
1038 
1039  // Process the terminator.
// emitMerge is still set here only if no structured control flow op was
// found above, in which case the merge goes right before the terminator.
1040  if (emitMerge)
1041  if (failed(emitMerge()))
1042  return failure();
1043  if (failed(processOperation(&block->back())))
1044  return failure();
1045 
1046  return success();
1047 }
1048 
/// Emits one OpPhi into functionBody for each argument of `block`, unless the
/// block has no arguments or is an entry block.
///
/// Each OpPhi is encoded as:
///   [phi type <id>, phi <id>, (value <id>, parent block <id>)...]
/// with one pair per MLIR predecessor. The parent block is mapped through
/// getPhiIncomingBlock, because structured control flow expands into several
/// SPIR-V blocks.
///
/// An incoming value with no <id> yet is encoded as 0, and its word index in
/// functionBody is recorded in deferredPhiValues for a later fix-up. Each
/// argument is mapped to its phi <id> in valueIDMap.
///
/// Fails if a predecessor's terminator is neither spirv.Branch nor
/// spirv.BranchConditional, or if a phi type cannot be processed.
1049 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1050  // Nothing to do if this block has no arguments or it's the entry block, which
1051  // always has the same arguments as the function signature.
1052  if (block->args_empty() || block->isEntryBlock())
1053  return success();
1054 
1055  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1056 
1057  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1058  // A SPIR-V OpPhi instruction is of the syntax:
1059  // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1060  // So we need to collect all predecessor blocks and the arguments they send
1061  // to this block.
// NOTE(review): the declaration of `predecessors` (original line 1062) is
// missing from this listing. From its uses below it holds
// (incoming SPIR-V block, forwarded operand range) pairs, presumably a
// SmallVector<std::pair<Block *, OperandRange>, N> -- confirm against source.
1063  for (Block *mlirPredecessor : block->getPredecessors()) {
1064  auto *terminator = mlirPredecessor->getTerminator();
1065  LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1066  LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1067  LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1068  // The predecessor here is the immediate one according to MLIR's IR
1069  // structure. It does not directly map to the incoming parent block for the
1070  // OpPhi instructions at SPIR-V binary level. This is because structured
1071  // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1072  // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1073  // the branch op jumping to the OpPhi's block then resides in the previous
1074  // structured control flow op's merge block.
1075  Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1076  LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1077  LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1078  if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1079  predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1080  } else if (auto branchCondOp =
1081  dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1082  std::optional<OperandRange> blockOperands;
// NOTE(review): if both targets are `block`, this predecessor presumably
// appears twice in getPredecessors(), and the true-target operands are used
// both times -- confirm whether that case can occur.
1083  if (branchCondOp.getTrueTarget() == block) {
1084  blockOperands = branchCondOp.getTrueTargetOperands();
1085  } else {
1086  assert(branchCondOp.getFalseTarget() == block);
1087  blockOperands = branchCondOp.getFalseTargetOperands();
1088  }
1089 
1090  assert(!blockOperands->empty() &&
1091  "expected non-empty block operand range");
1092  predecessors.emplace_back(spirvPredecessor, *blockOperands);
1093  } else {
1094  return terminator->emitError("unimplemented terminator for Phi creation");
1095  }
1096  LLVM_DEBUG({
1097  llvm::dbgs() << " block arguments:\n";
1098  for (Value v : predecessors.back().second)
1099  llvm::dbgs() << " " << v << "\n";
1100  });
1101  }
1102 
1103  // Then create OpPhi instruction for each of the block argument.
1104  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1105  BlockArgument arg = block->getArgument(argIndex);
1106 
1107  // Get the type <id> and result <id> for this OpPhi instruction.
1108  uint32_t phiTypeID = 0;
1109  if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1110  return failure();
1111  uint32_t phiID = getNextID();
1112 
1113  LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1114  << arg << " (id = " << phiID << ")\n");
1115 
1116  // Prepare the (value <id>, parent block <id>) pairs.
1117  SmallVector<uint32_t, 8> phiArgs;
1118  phiArgs.push_back(phiTypeID);
1119  phiArgs.push_back(phiID);
1120 
1121  for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1122  Value value = predecessors[predIndex].second[argIndex];
1123  uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1124  LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1125  << ") value " << value << ' ');
1126  // Each pair is a value <id> ...
1127  uint32_t valueId = getValueID(value);
1128  if (valueId == 0) {
1129  // The op generating this value hasn't been visited yet so we don't have
1130  // an <id> assigned yet. Record this to fix up later.
1131  LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
// Recorded index = current end of functionBody + one leading word
// (presumably the word-count/opcode word that encodeInstructionInto writes
// first) + phiArgs.size(), which is where valueId is pushed just below.
// The index therefore stays valid only if nothing else is appended to
// functionBody before this OpPhi is encoded.
1132  deferredPhiValues[value].push_back(functionBody.size() + 1 +
1133  phiArgs.size());
1134  } else {
1135  LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1136  }
1137  phiArgs.push_back(valueId);
1138  // ... and a parent block <id>.
1139  phiArgs.push_back(predBlockId);
1140  }
1141 
1142  encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1143  valueIDMap[arg] = phiID;
1144  }
1145 
1146  return success();
1147 }
1148 
1149 //===----------------------------------------------------------------------===//
1150 // Operation
1151 //===----------------------------------------------------------------------===//
1152 
1153 LogicalResult Serializer::encodeExtensionInstruction(
1154  Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1155  ArrayRef<uint32_t> operands) {
1156  // Check if the extension has been imported.
1157  auto &setID = extendedInstSetIDMap[extensionSetName];
1158  if (!setID) {
1159  setID = getNextID();
1160  SmallVector<uint32_t, 16> importOperands;
1161  importOperands.push_back(setID);
1162  spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1163  encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1164  importOperands);
1165  }
1166 
1167  // The first two operands are the result type <id> and result <id>. The set
1168  // <id> and the opcode need to be insert after this.
1169  if (operands.size() < 2) {
1170  return op->emitError("extended instructions must have a result encoding");
1171  }
1172  SmallVector<uint32_t, 8> extInstOperands;
1173  extInstOperands.reserve(operands.size() + 2);
1174  extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1175  extInstOperands.push_back(setID);
1176  extInstOperands.push_back(extensionOpcode);
1177  extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1178  encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1179  extInstOperands);
1180  return success();
1181 }
1182 
/// Serializes a single operation `opInst` by dispatching on its concrete op
/// type. Ops that need hand-written handling go to the matching process*Op
/// method. Every other op falls back to dispatchToAutogenSerialization.
/// Returns whatever the selected handler returns.
1183 LogicalResult Serializer::processOperation(Operation *opInst) {
1184  LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1185 
1186  // First dispatch the ops that do not directly mirror an instruction from
1187  // the SPIR-V spec.
// NOTE(review): the head of this call chain (original line 1188) is missing
// from this listing. It is presumably
// `return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)` -- confirm.
1189  .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1190  .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1191  .Case([&](spirv::BranchConditionalOp op) {
1192  return processBranchConditionalOp(op);
1193  })
1194  .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1195  .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1196  .Case([&](spirv::GlobalVariableOp op) {
1197  return processGlobalVariableOp(op);
1198  })
1199  .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1200  .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1201  .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1202  .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1203  .Case([&](spirv::SpecConstantCompositeOp op) {
1204  return processSpecConstantCompositeOp(op);
1205  })
1206  .Case([&](spirv::SpecConstantOperationOp op) {
1207  return processSpecConstantOperationOp(op);
1208  })
1209  .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1210  .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1211 
1212  // Then handle all the ops that directly mirror SPIR-V instructions with
1213  // auto-generated methods.
1214  .Default(
1215  [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1216 }
1217 
1218 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1219  StringRef extInstSet,
1220  uint32_t opcode) {
1221  SmallVector<uint32_t, 4> operands;
1222  Location loc = op->getLoc();
1223 
1224  uint32_t resultID = 0;
1225  if (op->getNumResults() != 0) {
1226  uint32_t resultTypeID = 0;
1227  if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1228  return failure();
1229  operands.push_back(resultTypeID);
1230 
1231  resultID = getNextID();
1232  operands.push_back(resultID);
1233  valueIDMap[op->getResult(0)] = resultID;
1234  };
1235 
1236  for (Value operand : op->getOperands())
1237  operands.push_back(getValueID(operand));
1238 
1239  if (failed(emitDebugLine(functionBody, loc)))
1240  return failure();
1241 
1242  if (extInstSet.empty()) {
1243  encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1244  operands);
1245  } else {
1246  if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1247  return failure();
1248  }
1249 
1250  if (op->getNumResults() != 0) {
1251  for (auto attr : op->getAttrs()) {
1252  if (failed(processDecoration(loc, resultID, attr)))
1253  return failure();
1254  }
1255  }
1256 
1257  return success();
1258 }
1259 
1260 LogicalResult Serializer::emitDecoration(uint32_t target,
1261  spirv::Decoration decoration,
1262  ArrayRef<uint32_t> params) {
1263  uint32_t wordCount = 3 + params.size();
1264  decorations.push_back(
1265  spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate));
1266  decorations.push_back(target);
1267  decorations.push_back(static_cast<uint32_t>(decoration));
1268  decorations.append(params.begin(), params.end());
1269  return success();
1270 }
1271 
1272 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1273  Location loc) {
1274  if (!options.emitDebugInfo)
1275  return success();
1276 
1277  if (lastProcessedWasMergeInst) {
1278  lastProcessedWasMergeInst = false;
1279  return success();
1280  }
1281 
1282  auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1283  if (fileLoc)
1284  encodeInstructionInto(binary, spirv::Opcode::OpLine,
1285  {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1286  return success();
1287 }
1288 } // namespace spirv
1289 } // namespace mlir
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
Definition: Serializer.cpp:36
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Definition: Serializer.cpp:47
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:310
Location getLoc() const
Return the location for this argument.
Definition: Value.h:325
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:122
unsigned getNumArguments()
Definition: Block.h:121
Operation & back()
Definition: Block.h:145
iterator_range< pred_iterator > getPredecessors()
Definition: Block.h:230
OpListType & getOperations()
Definition: Block.h:130
void print(raw_ostream &os)
bool args_empty()
Definition: Block.h:92
iterator end()
Definition: Block.h:137
iterator begin()
Definition: Block.h:136
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:30
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:216
An attribute that represents a reference to a dense vector or tensor object.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:198
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:49
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:212
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:486
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:68
unsigned getArrayStride() const
Returns the array stride in bytes.
Definition: SPIRVTypes.cpp:613
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Definition: Serializer.cpp:139
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
Definition: Serializer.cpp:85
LogicalResult serialize()
Serializes the remembered SPIR-V module.
Definition: Serializer.cpp:89
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
Definition: Serializer.cpp:113
SPIR-V struct type.
Definition: SPIRVTypes.h:284
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:78
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
Definition: Serialization.h:27
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.
Definition: Serialization.h:29