MLIR 23.0.0git
Serializer.cpp
Go to the documentation of this file.
1//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
27#include <cstdint>
28#include <optional>
29
30#define DEBUG_TYPE "spirv-serialization"
31
32using namespace mlir;
33
34/// Returns the merge block if the given `op` is a structured control flow op.
35/// Otherwise returns nullptr.
/// Only spirv::SelectionOp and spirv::LoopOp are recognized; for either one the
/// result is the op's getMergeBlock(). Any other kind of op yields nullptr.
37 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
41 return nullptr;
42}
43
44/// Given a predecessor `block` for a block with arguments, returns the block
45/// that should be used as the parent block for SPIR-V OpPhi instructions
46/// corresponding to the block arguments.
///
/// - If `block` is the entry block of a spirv::LoopOp, the result is the merge
///   block of the nearest structured control flow op that precedes the loop in
///   its parent block, or that parent block itself if there is none.
/// - Otherwise the result is the merge block of the last structured control
///   flow op inside `block`, or `block` itself if it contains none.
48 // If the predecessor block in question is the entry block for a
49 // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50 if (block->isEntryBlock()) {
51 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52 // Then the incoming parent block for OpPhi should be the merge block of
53 // the structured control flow op before this loop.
// Walk backwards from the loop so the nearest preceding op wins.
54 Operation *op = loopOp.getOperation();
55 while ((op = op->getPrevNode()) != nullptr)
56 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57 return incomingBlock;
58 // Or the enclosing block itself if no structured control flow ops
59 // exists before this loop.
60 return loopOp->getBlock();
61 }
62 }
63
64 // Otherwise, we jump from the given predecessor block. Try to see if there is
65 // a structured control flow op inside it.
// Scan from the end so the last such op in the block is the one returned.
66 for (Operation &op : llvm::reverse(block->getOperations())) {
67 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68 return incomingBlock;
69 }
70 return block;
71}
72
73static bool isZeroValue(Attribute attr) {
74 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
76 }
77 if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
79 }
80 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
82 }
83 if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
84 return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
85 }
86 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
87 return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
88 }
89 return false;
90}
91
92/// Move all functions declaration before functions definitions. In SPIR-V
93/// "declarations" are functions without a body and "definitions" functions
94/// with a body. This is stronger than necessary. It should be sufficient to
95/// ensure any declarations precede their uses and not all definitions, however
96/// this allows to avoid analysing every function in the module this way.
97static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp) {
98 Block::OpListType &ops = moduleOp.getBody()->getOperations();
99 if (ops.empty())
100 return;
101 Operation &firstOp = ops.front();
102 for (Operation &op : llvm::drop_begin(ops))
103 if (auto funcOp = dyn_cast<spirv::FuncOp>(op))
104 if (funcOp.getBody().empty())
105 funcOp->moveBefore(&firstOp);
106}
107
108namespace mlir {
109namespace spirv {
110
111/// Encodes a SPIR-V instruction with opcode `op` and the given `operands` into
112/// the given `binary` vector.
/// The leading word is spirv::getPrefixedOpcode(wordCount, op), where wordCount
/// is 1 + operands.size(); the operands follow verbatim. No kMaxWordCount limit
/// is checked here (encodeInstructionWithContinuationInto handles splitting).
114 ArrayRef<uint32_t> operands) {
115 uint32_t wordCount = 1 + operands.size();
116 binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
117 binary.append(operands.begin(), operands.end());
118}
119
120Serializer::Serializer(spirv::ModuleOp module,
121 const SerializationOptions &options)
122 : module(module), mlirBuilder(module.getContext()), options(options) {}
123
/// Serializes `module` into the per-section word buffers held by this
/// Serializer. Returns failure if the module fails verifyInvariants(), if the
/// required extensions cannot be emitted, or if any top-level op in the module
/// body fails to serialize.
124LogicalResult Serializer::serialize() {
125 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
126
127 if (failed(module.verifyInvariants()))
128 return failure();
129
// Module-level sections are emitted before any op in the body is visited.
130 // TODO: handle the other sections
131 processCapability();
132 if (failed(processExtension())) {
133 return failure();
134 }
135 processMemoryModel();
136 processDebugInfo();
137
139
140 // Iterate over the module body to serialize it. Assumptions are that there is
141 // only one basic block in the moduleOp
142 for (auto &op : *module.getBody()) {
143 if (failed(processOperation(&op))) {
144 return failure();
145 }
146 }
147
148 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
149 return success();
150}
151
// Assembles the final module into `binary`: the module header followed by
// each section buffer, in the fixed order of the append calls below.
// NOTE(review): `moduleSize` does not include debug, names or graphsDebugInfo,
// which are appended below, so reserve() may under-allocate. reserve() is
// only a capacity hint, so this affects allocations, not the emitted words.
153 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
154 extensions.size() + extendedSets.size() +
155 memoryModel.size() + entryPoints.size() +
156 executionModes.size() + decorations.size() +
157 typesGlobalValues.size() + functions.size() + graphs.size();
158
159 binary.clear();
160 binary.reserve(moduleSize);
161
// NOTE(review): getVceTriple() is dereferenced unchecked; presumably module
// verification guarantees it is present -- confirm.
162 spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
163 nextID);
164 binary.append(capabilities.begin(), capabilities.end());
165 binary.append(extensions.begin(), extensions.end());
166 binary.append(extendedSets.begin(), extendedSets.end());
167 binary.append(memoryModel.begin(), memoryModel.end());
168 binary.append(entryPoints.begin(), entryPoints.end());
169 binary.append(executionModes.begin(), executionModes.end());
170 binary.append(debug.begin(), debug.end());
171 binary.append(names.begin(), names.end());
172 binary.append(decorations.begin(), decorations.end());
173 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
174 binary.append(functions.begin(), functions.end());
175 binary.append(graphs.begin(), graphs.end());
176 binary.append(graphsDebugInfo.begin(), graphsDebugInfo.end());
177}
178
179#ifndef NDEBUG
/// Debug-only (compiled out under NDEBUG): prints every entry of valueIDMap to
/// `os` as "<value> id = <id>", followed by the defining op's name or, for a
/// block argument, the owning block and the name of its parent op.
181 os << "\n= Value <id> Map =\n\n";
182 for (auto valueIDPair : valueIDMap) {
183 Value val = valueIDPair.first;
184 os << " " << val << " "
185 << "id = " << valueIDPair.second << ' ';
186 if (auto *op = val.getDefiningOp()) {
187 os << "from op '" << op->getName() << "'";
188 } else if (auto arg = dyn_cast<BlockArgument>(val)) {
189 Block *block = arg.getOwner();
190 os << "from argument of block " << block << ' ';
191 os << " in op '" << block->getParentOp()->getName() << "'";
192 }
193 os << '\n';
194 }
195}
196#endif
197
198//===----------------------------------------------------------------------===//
199// Module structure
200//===----------------------------------------------------------------------===//
201
202uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
203 auto funcID = funcIDMap.lookup(fnName);
204 if (!funcID) {
205 funcID = getNextID();
206 funcIDMap[fnName] = funcID;
207 }
208 return funcID;
209}
210
211void Serializer::processCapability() {
212 for (auto cap : module.getVceTriple()->getCapabilities())
213 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
214 {static_cast<uint32_t>(cap)});
215}
216
/// Requests the LongCompositesINTEL capability and the
/// SPV_INTEL_long_composites extension. Runs its body at most once per
/// Serializer (guarded by `longCompositesEmitted`). Each item is skipped when
/// the module's VCE triple already lists it, because processCapability and
/// processExtension emit everything in the triple.
217void Serializer::addLongCompositesCapability() {
218 if (longCompositesEmitted)
219 return;
220 longCompositesEmitted = true;
221 auto vceTriple = module.getVceTriple();
222 if (!llvm::is_contained(vceTriple->getCapabilities(),
223 spirv::Capability::LongCompositesINTEL))
225 capabilities, spirv::Opcode::OpCapability,
226 {static_cast<uint32_t>(spirv::Capability::LongCompositesINTEL)});
// Append the extension directly to the extensions section as well.
227 if (!llvm::is_contained(vceTriple->getExtensions(),
228 spirv::Extension::SPV_INTEL_long_composites)) {
229 SmallVector<uint32_t, 8> extName;
231 extName,
232 spirv::stringifyExtension(spirv::Extension::SPV_INTEL_long_composites));
233 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
234 }
235}
236
/// Encodes `op` with `operands` into `binary`, splitting it when necessary.
/// If 1 + operands.size() fits within spirv::kMaxWordCount, a single
/// instruction is emitted. Otherwise the first (kMaxWordCount - 1) operands go
/// into `op`, and the rest are emitted in chunks of the same size using the
/// continuation opcode for `op`. The long-composites capability/extension is
/// then requested.
237void Serializer::encodeInstructionWithContinuationInto(
238 SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
239 ArrayRef<uint32_t> operands) {
240 if (1 + operands.size() <= spirv::kMaxWordCount) {
241 encodeInstructionInto(binary, op, operands);
242 return;
243 }
244
245 std::optional<spirv::Opcode> continuationOp =
247 assert(continuationOp && "op is not a splittable composite/struct opcode");
248
// NOTE(review): with assertions disabled, an `op` without a continuation
// opcode would dereference an empty optional in the loop below.
249 const unsigned chunk = spirv::kMaxWordCount - 1;
250 encodeInstructionInto(binary, op, operands.take_front(chunk));
251 for (ArrayRef<uint32_t> rest = operands.drop_front(chunk); !rest.empty();
252 rest = rest.drop_front(std::min<size_t>(rest.size(), chunk))) {
253 encodeInstructionInto(binary, *continuationOp, rest.take_front(chunk));
254 }
255
256 addLongCompositesCapability();
257}
258
259void Serializer::processDebugInfo() {
260 if (!options.emitDebugInfo)
261 return;
262 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
263 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
264 fileID = getNextID();
265 SmallVector<uint32_t, 16> operands;
266 operands.push_back(fileID);
267 spirv::encodeStringLiteralInto(operands, fileName);
268 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
269 // TODO: Encode more debug instructions.
270}
271
272LogicalResult Serializer::processExtension() {
273 llvm::SmallVector<uint32_t, 16> extName;
274 llvm::SmallSet<Extension, 4> deducedExts(
275 llvm::from_range, module.getVceTriple()->getExtensions());
276 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
277 if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
278 TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module);
279 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
280 return module.emitError(
281 "SPV_KHR_non_semantic_info extension not available");
282 deducedExts.insert(nonSemanticInfoExt);
283 }
284 for (spirv::Extension ext : deducedExts) {
285 extName.clear();
286 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
287 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
288 }
289 return success();
290}
291
292void Serializer::processMemoryModel() {
293 StringAttr memoryModelName = module.getMemoryModelAttrName();
294 auto mm = static_cast<uint32_t>(
295 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
296 .getValue());
297
298 StringAttr addressingModelName = module.getAddressingModelAttrName();
299 auto am = static_cast<uint32_t>(
300 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
301 .getValue());
302
303 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
304}
305
306static std::string getDecorationName(StringRef attrName) {
307 // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
308 // expected FPFastMathMode.
309 if (attrName == "fp_fast_math_mode")
310 return "FPFastMathMode";
311 // similar here
312 if (attrName == "fp_rounding_mode")
313 return "FPRoundingMode";
314 // convertToCamelFromSnakeCase will not capitalize "INTEL".
315 if (attrName == "cache_control_load_intel")
316 return "CacheControlLoadINTEL";
317 if (attrName == "cache_control_store_intel")
318 return "CacheControlStoreINTEL";
319
320 return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
321}
322
323template <typename AttrTy, typename EmitF>
324static LogicalResult processDecorationList(Location loc, Decoration decoration,
325 Attribute attrList,
326 StringRef attrName, EmitF emitter) {
327 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
328 if (!arrayAttr) {
329 return emitError(loc, "expecting array attribute of ")
330 << attrName << " for " << stringifyDecoration(decoration);
331 }
332 if (arrayAttr.empty()) {
333 return emitError(loc, "expecting non-empty array attribute of ")
334 << attrName << " for " << stringifyDecoration(decoration);
335 }
336 for (Attribute attr : arrayAttr.getValue()) {
337 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
338 if (!cacheControlAttr) {
339 return emitError(loc, "expecting array attribute of ")
340 << attrName << " for " << stringifyDecoration(decoration);
341 }
342 // This named attribute encodes several decorations. Emit one per
343 // element in the array.
344 if (failed(emitter(cacheControlAttr)))
345 return failure();
346 }
347 return success();
348}
349
/// Emits `decoration` with value `attr` on the object `resultID`.
/// Most cases validate `attr`, push literal operands into `args`, and `break`
/// to the single emitDecoration call at the end. CacheControl*INTEL
/// decorations are handled by processDecorationList, one decoration per array
/// element. AlignmentId, MaxByteOffsetId and CounterBuffer resolve a symbol to
/// an <id> and return via emitDecorationId. Unsupported decorations, or
/// attributes of the wrong kind, produce an error at `loc`.
350LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
351 Decoration decoration,
352 Attribute attr) {
354 switch (decoration) {
355 case spirv::Decoration::LinkageAttributes: {
356 // Get the value of the Linkage Attributes
357 // e.g., LinkageAttributes=["linkageName", linkageType].
358 auto linkageAttr = dyn_cast<spirv::LinkageAttributesAttr>(attr);
// NOTE(review): `linkageAttr` is used without a null check. If `attr` is not
// a LinkageAttributesAttr, dyn_cast yields null and the calls below are
// invalid. Other cases report an error instead; consider doing so here.
359 auto linkageName = linkageAttr.getLinkageName();
360 auto linkageType = linkageAttr.getLinkageType().getValue();
361 // Encode the Linkage Name (string literal to uint32_t).
362 spirv::encodeStringLiteralInto(args, linkageName);
363 // Encode LinkageType & Add the Linkagetype to the args.
364 args.push_back(static_cast<uint32_t>(linkageType));
365 break;
366 }
367 case spirv::Decoration::FPFastMathMode:
368 if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
369 args.push_back(static_cast<uint32_t>(intAttr.getValue()));
370 break;
371 }
372 return emitError(loc, "expected FPFastMathModeAttr attribute for ")
373 << stringifyDecoration(decoration);
374 case spirv::Decoration::FPRoundingMode:
375 if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
376 args.push_back(static_cast<uint32_t>(intAttr.getValue()));
377 break;
378 }
379 return emitError(loc, "expected FPRoundingModeAttr attribute for ")
380 << stringifyDecoration(decoration);
381 case spirv::Decoration::Binding:
382 case spirv::Decoration::DescriptorSet:
383 case spirv::Decoration::Location:
384 case spirv::Decoration::Index:
385 case spirv::Decoration::Offset:
386 case spirv::Decoration::XfbBuffer:
387 case spirv::Decoration::XfbStride:
388 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
389 args.push_back(intAttr.getValue().getZExtValue());
390 break;
391 }
392 return emitError(loc, "expected integer attribute for ")
393 << stringifyDecoration(decoration);
394 case spirv::Decoration::BuiltIn:
395 if (auto strAttr = dyn_cast<StringAttr>(attr)) {
396 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
397 if (enumVal) {
398 args.push_back(static_cast<uint32_t>(*enumVal));
399 break;
400 }
401 return emitError(loc, "invalid ")
402 << stringifyDecoration(decoration) << " decoration attribute "
403 << strAttr.getValue();
404 }
405 return emitError(loc, "expected string attribute for ")
406 << stringifyDecoration(decoration);
407 case spirv::Decoration::Aliased:
408 case spirv::Decoration::AliasedPointer:
409 case spirv::Decoration::Flat:
410 case spirv::Decoration::NonReadable:
411 case spirv::Decoration::NonWritable:
412 case spirv::Decoration::NoPerspective:
413 case spirv::Decoration::NoSignedWrap:
414 case spirv::Decoration::NoUnsignedWrap:
415 case spirv::Decoration::RelaxedPrecision:
416 case spirv::Decoration::Restrict:
417 case spirv::Decoration::RestrictPointer:
418 case spirv::Decoration::NoContraction:
419 case spirv::Decoration::Constant:
420 case spirv::Decoration::Block:
421 case spirv::Decoration::Invariant:
422 case spirv::Decoration::Patch:
423 case spirv::Decoration::Coherent:
424 // For unit attributes and decoration attributes, the args list
425 // has no values so we do nothing.
426 if (isa<UnitAttr, DecorationAttr>(attr))
427 break;
428 return emitError(loc,
429 "expected unit attribute or decoration attribute for ")
430 << stringifyDecoration(decoration);
// Cache-control decorations carry an array of (cache level, control) pairs;
// each element becomes its own decoration through the lambda below.
431 case spirv::Decoration::CacheControlLoadINTEL:
433 loc, decoration, attr, "CacheControlLoadINTEL",
434 [&](CacheControlLoadINTELAttr attr) {
435 unsigned cacheLevel = attr.getCacheLevel();
436 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
437 return emitDecoration(
438 resultID, decoration,
439 {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
440 });
441 case spirv::Decoration::CacheControlStoreINTEL:
443 loc, decoration, attr, "CacheControlStoreINTEL",
444 [&](CacheControlStoreINTELAttr attr) {
445 unsigned cacheLevel = attr.getCacheLevel();
446 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
447 return emitDecoration(
448 resultID, decoration,
449 {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
450 });
// These decorations take an <id> operand. The referenced symbol is resolved
// as a global variable first, then as a spec constant.
451 case spirv::Decoration::AlignmentId:
452 case spirv::Decoration::MaxByteOffsetId:
453 case spirv::Decoration::CounterBuffer: {
454 auto symRef = dyn_cast<FlatSymbolRefAttr>(attr);
455 if (!symRef)
456 return emitError(loc, "expected symbol reference for ")
457 << stringifyDecoration(decoration);
458 StringRef symName = symRef.getValue();
459 uint32_t operandID = getVariableID(symName);
460 if (!operandID)
461 operandID = getSpecConstID(symName);
462 if (!operandID)
463 return emitError(loc, "could not find <id> for symbol '")
464 << symName << "' referenced by "
465 << stringifyDecoration(decoration);
466 return emitDecorationId(resultID, decoration, {operandID});
467 }
468 default:
469 return emitError(loc, "unhandled decoration ")
470 << stringifyDecoration(decoration);
471 }
// Every case that `break`s ends here, with its literal operands in `args`.
472 return emitDecoration(resultID, decoration, args);
473}
474
475LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
476 NamedAttribute attr) {
477 StringRef attrName = attr.getName().strref();
478 std::string decorationName = getDecorationName(attrName);
479 std::optional<Decoration> decoration =
480 spirv::symbolizeDecoration(decorationName);
481 if (!decoration) {
482 return emitError(
483 loc, "non-argument attributes expected to have snake-case-ified "
484 "decoration name, unhandled attribute with name : ")
485 << attrName;
486 }
487 return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
488}
489
490LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
491 assert(!name.empty() && "unexpected empty string for OpName");
492 if (!options.emitSymbolName)
493 return success();
494
495 SmallVector<uint32_t, 4> nameOperands;
496 nameOperands.push_back(resultID);
497 spirv::encodeStringLiteralInto(nameOperands, name);
498 encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
499 return success();
500}
501
502template <>
503LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
504 Location loc, spirv::ArrayType type, uint32_t resultID) {
505 if (unsigned stride = type.getArrayStride()) {
506 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
507 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
508 }
509 return success();
510}
511
512template <>
513LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
514 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
515 if (unsigned stride = type.getArrayStride()) {
516 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
517 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
518 }
519 return success();
520}
521
/// Emits an OpMemberDecorate for member `memberDecoration.memberIndex` of the
/// struct type `structID`, into the decorations section. When the decoration
/// carries a value, it is appended as one extra operand. That value must be an
/// IntegerAttr: cast<> asserts otherwise.
522LogicalResult Serializer::processMemberDecoration(
523 uint32_t structID,
524 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
526 {structID, memberDecoration.memberIndex,
527 static_cast<uint32_t>(memberDecoration.decoration)});
528 if (memberDecoration.hasValue()) {
529 args.push_back(
530 cast<IntegerAttr>(memberDecoration.decorationValue).getInt());
531 }
532 encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
533 return success();
534}
535
536//===----------------------------------------------------------------------===//
537// Type
538//===----------------------------------------------------------------------===//
539
540// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
541// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
542// PushConstant Storage Classes must be explicitly laid out."
543bool Serializer::isInterfaceStructPtrType(Type type) const {
544 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
545 switch (ptrType.getStorageClass()) {
546 case spirv::StorageClass::PhysicalStorageBuffer:
547 case spirv::StorageClass::PushConstant:
548 case spirv::StorageClass::StorageBuffer:
549 case spirv::StorageClass::Uniform:
550 return isa<spirv::StructType>(ptrType.getPointeeType());
551 default:
552 break;
553 }
554 }
555 return false;
556}
557
558LogicalResult Serializer::processType(Location loc, Type type,
559 uint32_t &typeID) {
560 // Maintains a set of names for nested identified struct types. This is used
561 // to properly serialize recursive references.
562 SetVector<StringRef> serializationCtx;
563 return processTypeImpl(loc, type, typeID, serializationCtx);
564}
565
/// Returns in `typeID` the <id> for `type`. On first use, the type declaration
/// is emitted into typesGlobalValues and cached in typeIDMap. Unsigned integer
/// types are first canonicalized to signless.
/// `serializationCtx` holds the identified struct names currently being
/// serialized, so that recursive pointer references can be deferred. When the
/// preparer requests deferral, `typeID` is still assigned, but nothing is
/// emitted or cached here.
566LogicalResult
567Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
568 SetVector<StringRef> &serializationCtx) {
569
570 // Map unsigned integer types to signless integer types.
571 // This is needed otherwise the generated spirv assembly will contain
572 // twice a type declaration (like OpTypeInt 32 0) which is not permitted and
573 // such module fails validation. Indeed at MLIR level the two types are
574 // different and lookup in the cache below misses.
575 // Note: This conversion needs to happen here before the type is looked up in
576 // the cache.
577 if (type.isUnsignedInteger()) {
578 type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
579 IntegerType::SignednessSemantics::Signless);
580 }
581
582 typeID = getTypeID(type);
583 if (typeID)
584 return success();
585
586 typeID = getNextID();
587 SmallVector<uint32_t, 4> operands;
588
589 operands.push_back(typeID);
590 auto typeEnum = spirv::Opcode::OpTypeVoid;
591 bool deferSerialization = false;
592
// Function and graph types go through their dedicated preparers first. If
// that does not succeed, or for any other type, prepareBasicType is tried.
593 if ((isa<FunctionType>(type) &&
594 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
595 operands))) ||
596 (isa<GraphType>(type) &&
597 succeeded(
598 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
599 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
600 deferSerialization, serializationCtx))) {
601 if (deferSerialization)
602 return success();
603
604 typeIDMap[type] = typeID;
605
// Only struct declarations may exceed the word-count limit and need splitting.
606 if (typeEnum == spirv::Opcode::OpTypeStruct)
607 encodeInstructionWithContinuationInto(typesGlobalValues, typeEnum,
608 operands);
609 else
610 encodeInstructionInto(typesGlobalValues, typeEnum, operands);
611
612 if (recursiveStructInfos.count(type) != 0) {
613 // This recursive struct type is emitted already, now the OpTypePointer
614 // instructions referring to recursive references are emitted as well.
615 for (auto &ptrInfo : recursiveStructInfos[type]) {
616 // TODO: This might not work if more than 1 recursive reference is
617 // present in the struct.
618 SmallVector<uint32_t, 4> ptrOperands;
619 ptrOperands.push_back(ptrInfo.pointerTypeID);
620 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
621 ptrOperands.push_back(typeIDMap[type]);
622
623 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
624 ptrOperands);
625 }
626
627 recursiveStructInfos[type].clear();
628 }
629
630 return success();
631 }
632
633 return emitError(loc, "failed to process type: ") << type;
634}
635
636LogicalResult Serializer::prepareBasicType(
637 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
638 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
639 SetVector<StringRef> &serializationCtx) {
640 deferSerialization = false;
641
642 if (isVoidType(type)) {
643 typeEnum = spirv::Opcode::OpTypeVoid;
644 return success();
645 }
646
647 if (auto intType = dyn_cast<IntegerType>(type)) {
648 if (intType.getWidth() == 1) {
649 typeEnum = spirv::Opcode::OpTypeBool;
650 return success();
651 }
652
653 typeEnum = spirv::Opcode::OpTypeInt;
654 operands.push_back(intType.getWidth());
655 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
656 // to preserve or validate.
657 // 0 indicates unsigned, or no signedness semantics
658 // 1 indicates signed semantics."
659 operands.push_back(intType.isSigned() ? 1 : 0);
660 return success();
661 }
662
663 if (auto floatType = dyn_cast<FloatType>(type)) {
664 typeEnum = spirv::Opcode::OpTypeFloat;
665 operands.push_back(floatType.getWidth());
666 if (floatType.isBF16()) {
667 operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
668 }
669 if (floatType.isF8E4M3FN()) {
670 operands.push_back(
671 static_cast<uint32_t>(spirv::FPEncoding::Float8E4M3EXT));
672 }
673 if (floatType.isF8E5M2()) {
674 operands.push_back(
675 static_cast<uint32_t>(spirv::FPEncoding::Float8E5M2EXT));
676 }
677
678 return success();
679 }
680
681 if (auto vectorType = dyn_cast<VectorType>(type)) {
682 uint32_t elementTypeID = 0;
683 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
684 serializationCtx))) {
685 return failure();
686 }
687 typeEnum = spirv::Opcode::OpTypeVector;
688 operands.push_back(elementTypeID);
689 operands.push_back(vectorType.getNumElements());
690 return success();
691 }
692
693 if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
694 typeEnum = spirv::Opcode::OpTypeImage;
695 uint32_t sampledTypeID = 0;
696 if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
697 return failure();
698
699 llvm::append_values(operands, sampledTypeID,
700 static_cast<uint32_t>(imageType.getDim()),
701 static_cast<uint32_t>(imageType.getDepthInfo()),
702 static_cast<uint32_t>(imageType.getArrayedInfo()),
703 static_cast<uint32_t>(imageType.getSamplingInfo()),
704 static_cast<uint32_t>(imageType.getSamplerUseInfo()),
705 static_cast<uint32_t>(imageType.getImageFormat()));
706 return success();
707 }
708
709 if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
710 typeEnum = spirv::Opcode::OpTypeArray;
711 uint32_t elementTypeID = 0;
712 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
713 serializationCtx))) {
714 return failure();
715 }
716 operands.push_back(elementTypeID);
717 if (auto elementCountID = prepareConstantInt(
718 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
719 operands.push_back(elementCountID);
720 }
721 return processTypeDecoration(loc, arrayType, resultID);
722 }
723
724 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
725 uint32_t pointeeTypeID = 0;
726 spirv::StructType pointeeStruct =
727 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
728
729 if (pointeeStruct && pointeeStruct.isIdentified() &&
730 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
731 // A recursive reference to an enclosing struct is found.
732 //
733 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
734 // class as operands.
735 SmallVector<uint32_t, 2> forwardPtrOperands;
736 forwardPtrOperands.push_back(resultID);
737 forwardPtrOperands.push_back(
738 static_cast<uint32_t>(ptrType.getStorageClass()));
739
740 encodeInstructionInto(typesGlobalValues,
741 spirv::Opcode::OpTypeForwardPointer,
742 forwardPtrOperands);
743
744 // 2. Find the pointee (enclosing) struct.
745 auto structType = spirv::StructType::getIdentified(
746 module.getContext(), pointeeStruct.getIdentifier());
747
748 if (!structType)
749 return failure();
750
751 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
752 // as deferred.
753 deferSerialization = true;
754
755 // 4. Record the info needed to emit the deferred OpTypePointer
756 // instruction when the enclosing struct is completely serialized.
757 recursiveStructInfos[structType].push_back(
758 {resultID, ptrType.getStorageClass()});
759 } else {
760 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
761 serializationCtx)))
762 return failure();
763 }
764
765 typeEnum = spirv::Opcode::OpTypePointer;
766 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
767 operands.push_back(pointeeTypeID);
768
769 // TODO: Now struct decorations are supported this code may not be
770 // necessary. However, it is left to support backwards compatibility.
771 // Ideally, Block decorations should be inserted when converting to SPIR-V.
772 if (isInterfaceStructPtrType(ptrType)) {
773 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
774 if (!structType.hasDecoration(spirv::Decoration::Block))
775 if (failed(emitDecoration(getTypeID(pointeeStruct),
776 spirv::Decoration::Block)))
777 return emitError(loc, "cannot decorate ")
778 << pointeeStruct << " with Block decoration";
779 }
780
781 return success();
782 }
783
784 if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
785 uint32_t elementTypeID = 0;
786 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
787 elementTypeID, serializationCtx))) {
788 return failure();
789 }
790 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
791 operands.push_back(elementTypeID);
792 return processTypeDecoration(loc, runtimeArrayType, resultID);
793 }
794
795 if (isa<spirv::SamplerType>(type)) {
796 typeEnum = spirv::Opcode::OpTypeSampler;
797 return success();
798 }
799
800 if (isa<spirv::NamedBarrierType>(type)) {
801 typeEnum = spirv::Opcode::OpTypeNamedBarrier;
802 return success();
803 }
804
805 if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
806 typeEnum = spirv::Opcode::OpTypeSampledImage;
807 uint32_t imageTypeID = 0;
808 if (failed(
809 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
810 return failure();
811 }
812 operands.push_back(imageTypeID);
813 return success();
814 }
815
816 if (auto structType = dyn_cast<spirv::StructType>(type)) {
817 if (structType.isIdentified()) {
818 if (failed(processName(resultID, structType.getIdentifier())))
819 return failure();
820 serializationCtx.insert(structType.getIdentifier());
821 }
822
823 bool hasOffset = structType.hasOffset();
824 for (auto elementIndex :
825 llvm::seq<uint32_t>(0, structType.getNumElements())) {
826 uint32_t elementTypeID = 0;
827 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
828 elementTypeID, serializationCtx))) {
829 return failure();
830 }
831 operands.push_back(elementTypeID);
832 if (hasOffset) {
833 auto intType = IntegerType::get(structType.getContext(), 32);
834 // Decorate each struct member with an offset
835 spirv::StructType::MemberDecorationInfo offsetDecoration{
836 elementIndex, spirv::Decoration::Offset,
837 IntegerAttr::get(intType,
838 structType.getMemberOffset(elementIndex))};
839 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
840 return emitError(loc, "cannot decorate ")
841 << elementIndex << "-th member of " << structType
842 << " with its offset";
843 }
844 }
845 }
846 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
847 structType.getMemberDecorations(memberDecorations);
848
849 for (auto &memberDecoration : memberDecorations) {
850 if (failed(processMemberDecoration(resultID, memberDecoration))) {
851 return emitError(loc, "cannot decorate ")
852 << static_cast<uint32_t>(memberDecoration.memberIndex)
853 << "-th member of " << structType << " with "
854 << stringifyDecoration(memberDecoration.decoration);
855 }
856 }
857
858 SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
859 structType.getStructDecorations(structDecorations);
860
861 for (spirv::StructType::StructDecorationInfo &structDecoration :
862 structDecorations) {
863 if (failed(processDecorationAttr(loc, resultID,
864 structDecoration.decoration,
865 structDecoration.decorationValue))) {
866 return emitError(loc, "cannot decorate struct ")
867 << structType << " with "
868 << stringifyDecoration(structDecoration.decoration);
869 }
870 }
871
872 typeEnum = spirv::Opcode::OpTypeStruct;
873
874 if (structType.isIdentified())
875 serializationCtx.remove(structType.getIdentifier());
876
877 return success();
878 }
879
880 if (auto cooperativeMatrixType =
881 dyn_cast<spirv::CooperativeMatrixType>(type)) {
882 uint32_t elementTypeID = 0;
883 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
884 elementTypeID, serializationCtx))) {
885 return failure();
886 }
887 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
888 auto getConstantOp = [&](uint32_t id) {
889 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
890 return prepareConstantInt(loc, attr);
891 };
892 llvm::append_values(
893 operands, elementTypeID,
894 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
895 getConstantOp(cooperativeMatrixType.getRows()),
896 getConstantOp(cooperativeMatrixType.getColumns()),
897 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
898 return success();
899 }
900
901 if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
902 uint32_t elementTypeID = 0;
903 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
904 serializationCtx))) {
905 return failure();
906 }
907 typeEnum = spirv::Opcode::OpTypeMatrix;
908 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
909 return success();
910 }
911
912 if (auto tensorArmType = dyn_cast<TensorArmType>(type)) {
913 uint32_t elementTypeID = 0;
914 uint32_t rank = 0;
915 uint32_t shapeID = 0;
916 uint32_t rankID = 0;
917 if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
918 elementTypeID, serializationCtx))) {
919 return failure();
920 }
921 if (tensorArmType.hasRank()) {
922 ArrayRef<int64_t> dims = tensorArmType.getShape();
923 rank = dims.size();
924 rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
925 if (rankID == 0) {
926 return failure();
927 }
928
929 bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
930 if (rank > 0 && shaped) {
931 auto I32Type = IntegerType::get(type.getContext(), 32);
932 auto shapeType = ArrayType::get(I32Type, rank);
933 if (rank == 1) {
934 SmallVector<uint64_t, 1> index(rank);
935 shapeID = prepareDenseElementsConstant(
936 loc, shapeType,
937 mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
938 index);
939 } else {
940 shapeID = prepareArrayConstant(
941 loc, shapeType,
942 mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
943 }
944 if (shapeID == 0) {
945 return failure();
946 }
947 }
948 }
949 typeEnum = spirv::Opcode::OpTypeTensorARM;
950 operands.push_back(elementTypeID);
951 if (rankID == 0)
952 return success();
953 operands.push_back(rankID);
954 if (shapeID == 0)
955 return success();
956 operands.push_back(shapeID);
957 return success();
958 }
959
960 // TODO: Handle other types.
961 return emitError(loc, "unhandled type in serialization: ") << type;
962}
963
964LogicalResult
965Serializer::prepareFunctionType(Location loc, FunctionType type,
966 spirv::Opcode &typeEnum,
967 SmallVectorImpl<uint32_t> &operands) {
968 typeEnum = spirv::Opcode::OpTypeFunction;
969 assert(type.getNumResults() <= 1 &&
970 "serialization supports only a single return value");
971 uint32_t resultID = 0;
972 if (failed(processType(
973 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
974 resultID))) {
975 return failure();
976 }
977 operands.push_back(resultID);
978 for (auto &res : type.getInputs()) {
979 uint32_t argTypeID = 0;
980 if (failed(processType(loc, res, argTypeID))) {
981 return failure();
982 }
983 operands.push_back(argTypeID);
984 }
985 return success();
986}
987
988LogicalResult
989Serializer::prepareGraphType(Location loc, GraphType type,
990 spirv::Opcode &typeEnum,
991 SmallVectorImpl<uint32_t> &operands) {
992 typeEnum = spirv::Opcode::OpTypeGraphARM;
993 assert(type.getNumResults() >= 1 &&
994 "serialization requires at least a return value");
995
996 operands.push_back(type.getNumInputs());
997
998 for (Type argType : type.getInputs()) {
999 uint32_t argTypeID = 0;
1000 if (failed(processType(loc, argType, argTypeID)))
1001 return failure();
1002 operands.push_back(argTypeID);
1003 }
1004
1005 for (Type resType : type.getResults()) {
1006 uint32_t resTypeID = 0;
1007 if (failed(processType(loc, resType, resTypeID)))
1008 return failure();
1009 operands.push_back(resTypeID);
1010 }
1011
1012 return success();
1013}
1014
1015//===----------------------------------------------------------------------===//
1016// Constant
1017//===----------------------------------------------------------------------===//
1018
1019uint32_t Serializer::prepareConstant(Location loc, Type constType,
1020 Attribute valueAttr) {
1021 if (auto id = prepareConstantScalar(loc, valueAttr)) {
1022 return id;
1023 }
1024
1025 // This is a composite literal. We need to handle each component separately
1026 // and then emit an OpConstantComposite for the whole.
1027
1028 if (auto id = getConstantID(valueAttr)) {
1029 return id;
1030 }
1031
1032 uint32_t typeID = 0;
1033 if (failed(processType(loc, constType, typeID))) {
1034 return 0;
1035 }
1036
1037 uint32_t resultID = 0;
1038 if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
1039 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
1040 SmallVector<uint64_t, 4> index(rank);
1041 resultID = prepareDenseElementsConstant(loc, constType, attr,
1042 /*dim=*/0, index);
1043 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1044 resultID = prepareArrayConstant(loc, constType, arrayAttr);
1045 }
1046
1047 if (resultID == 0) {
1048 emitError(loc, "cannot serialize attribute: ") << valueAttr;
1049 return 0;
1050 }
1051
1052 constIDMap[valueAttr] = resultID;
1053 return resultID;
1054}
1055
1056uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1057 ArrayAttr attr) {
1058 uint32_t typeID = 0;
1059 if (failed(processType(loc, constType, typeID))) {
1060 return 0;
1061 }
1062
1063 uint32_t resultID = getNextID();
1064 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1065 operands.reserve(attr.size() + 2);
1066 spirv::CompositeType compositeType = cast<spirv::CompositeType>(constType);
1067 for (auto [idx, elementAttr] : llvm::enumerate(attr)) {
1068 if (uint32_t elementID = prepareConstant(
1069 loc, compositeType.getElementType(idx), elementAttr)) {
1070 operands.push_back(elementID);
1071 } else {
1072 return 0;
1073 }
1074 }
1075 encodeInstructionWithContinuationInto(
1076 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1077
1078 return resultID;
1079}
1080
// TODO: Turn the below function into iterative function, instead of
// recursive function.
//
// Serializes (a slice of) `valueAttr` as a SPIR-V constant of type
// `constType` and returns its result <id>, or 0 on failure.
//
// The function recurses over the dimensions of the attribute's shaped type:
// `dim` is the dimension currently being expanded and `index` holds the
// coordinates already fixed for the outer dimensions (index[dim] is written
// here before recursing). When `dim` reaches the rank, the single element at
// `index` is emitted as a scalar constant.
uint32_t
Serializer::prepareDenseElementsConstant(Location loc, Type constType,
                                         DenseElementsAttr valueAttr, int dim,
                                         MutableArrayRef<uint64_t> index) {
  auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
  assert(dim <= shapedType.getRank());
  // Base case: every coordinate is fixed, so emit a scalar constant for the
  // element at `index`.
  if (shapedType.getRank() == dim) {
    if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
      // i1 elements become boolean constants; other integer widths become
      // OpConstant integers.
      return attr.getType().getElementType().isInteger(1)
                 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
                 : prepareConstantInt(loc,
                                      attr.getValues<IntegerAttr>()[index]);
    }
    if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
      return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
    }
    // Any other element kind is not supported by this path.
    return 0;
  }

  uint32_t typeID = 0;
  if (failed(processType(loc, constType, typeID))) {
    return 0;
  }

  int64_t numberOfConstituents = shapedType.getDimSize(dim);
  // The composite's result <id> is reserved before its constituents are
  // serialized.
  uint32_t resultID = getNextID();
  SmallVector<uint32_t, 4> operands = {typeID, resultID};
  auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
  // For a TensorArm with more than one dimension, each constituent is itself a
  // TensorArm over the remaining (inner) dimensions rather than a scalar.
  if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
    ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
    if (!innerShape.empty())
      elementType = spirv::TensorArmType::get(innerShape, elementType);
  }

  // "If the Result Type is a cooperative matrix type, then there must be only
  // one Constituent, with scalar type matching the cooperative matrix Component
  // Type, and all components of the matrix are initialized to that value."
  // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
  if (isa<spirv::CooperativeMatrixType>(constType)) {
    if (!valueAttr.isSplat()) {
      emitError(
          loc,
          "cannot serialize a non-splat value for a cooperative matrix type");
      return 0;
    }
    // numberOfConstituents is 1, so we only need one more elements in the
    // SmallVector, so the total is 3 (1 + 2).
    operands.reserve(3);
    // We set dim directly to `shapedType.getRank()` so the recursive call
    // directly returns the scalar type.
    if (auto elementID = prepareDenseElementsConstant(
            loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
      operands.push_back(elementID);
    } else {
      return 0;
    }
  } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
    // An all-zero TensorArm is emitted as a single OpConstantNull instead of
    // a composite of zero constituents.
    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
                          {typeID, resultID});
    return resultID;
  } else {
    // General case: one constituent per entry along dimension `dim`, each
    // produced by recursing into the next dimension with index[dim] = i.
    operands.reserve(numberOfConstituents + 2);
    for (int i = 0; i < numberOfConstituents; ++i) {
      index[dim] = i;
      if (auto elementID = prepareDenseElementsConstant(
              loc, elementType, valueAttr, dim + 1, index)) {
        operands.push_back(elementID);
      } else {
        return 0;
      }
    }
  }
  encodeInstructionWithContinuationInto(
      typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);

  return resultID;
}
1160
1161uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1162 bool isSpec) {
1163 if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1164 return prepareConstantFp(loc, floatAttr, isSpec);
1165 }
1166 if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1167 return prepareConstantBool(loc, boolAttr, isSpec);
1168 }
1169 if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1170 return prepareConstantInt(loc, intAttr, isSpec);
1171 }
1172
1173 return 0;
1174}
1175
1176uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1177 bool isSpec) {
1178 if (!isSpec) {
1179 // We can de-duplicate normal constants, but not specialization constants.
1180 if (auto id = getConstantID(boolAttr)) {
1181 return id;
1182 }
1183 }
1184
1185 // Process the type for this bool literal
1186 uint32_t typeID = 0;
1187 if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
1188 return 0;
1189 }
1190
1191 auto resultID = getNextID();
1192 auto opcode = boolAttr.getValue()
1193 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1194 : spirv::Opcode::OpConstantTrue)
1195 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1196 : spirv::Opcode::OpConstantFalse);
1197 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1198
1199 if (!isSpec) {
1200 constIDMap[boolAttr] = resultID;
1201 }
1202 return resultID;
1203}
1204
1205uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1206 bool isSpec) {
1207 if (!isSpec) {
1208 // We can de-duplicate normal constants, but not specialization constants.
1209 if (auto id = getConstantID(intAttr)) {
1210 return id;
1211 }
1212 }
1213
1214 // Process the type for this integer literal
1215 uint32_t typeID = 0;
1216 if (failed(processType(loc, intAttr.getType(), typeID))) {
1217 return 0;
1218 }
1219
1220 auto resultID = getNextID();
1221 APInt value = intAttr.getValue();
1222 unsigned bitwidth = value.getBitWidth();
1223 bool isSigned = intAttr.getType().isSignedInteger();
1224 auto opcode =
1225 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1226
1227 switch (bitwidth) {
1228 // According to SPIR-V spec, "When the type's bit width is less than
1229 // 32-bits, the literal's value appears in the low-order bits of the word,
1230 // and the high-order bits must be 0 for a floating-point type, or 0 for an
1231 // integer type with Signedness of 0, or sign extended when Signedness
1232 // is 1."
1233 case 32:
1234 case 16:
1235 case 8: {
1236 uint32_t word = 0;
1237 if (isSigned) {
1238 word = static_cast<int32_t>(value.getSExtValue());
1239 } else {
1240 word = static_cast<uint32_t>(value.getZExtValue());
1241 }
1242 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1243 } break;
1244 // According to SPIR-V spec: "When the type's bit width is larger than one
1245 // word, the literal’s low-order words appear first."
1246 case 64: {
1247 struct DoubleWord {
1248 uint32_t word1;
1249 uint32_t word2;
1250 } words;
1251 if (isSigned) {
1252 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1253 } else {
1254 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1255 }
1256 encodeInstructionInto(typesGlobalValues, opcode,
1257 {typeID, resultID, words.word1, words.word2});
1258 } break;
1259 default: {
1260 std::string valueStr;
1261 llvm::raw_string_ostream rss(valueStr);
1262 value.print(rss, /*isSigned=*/false);
1263
1264 emitError(loc, "cannot serialize ")
1265 << bitwidth << "-bit integer literal: " << valueStr;
1266 return 0;
1267 }
1268 }
1269
1270 if (!isSpec) {
1271 constIDMap[intAttr] = resultID;
1272 }
1273 return resultID;
1274}
1275
1276uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
1277 IntegerAttr intAttr) {
1278 // De-duplicate graph constants.
1279 if (uint32_t id = getGraphConstantARMId(intAttr)) {
1280 return id;
1281 }
1282
1283 // Process the type for this graph constant.
1284 uint32_t typeID = 0;
1285 if (failed(processType(loc, graphConstType, typeID))) {
1286 return 0;
1287 }
1288
1289 uint32_t resultID = getNextID();
1290 APInt value = intAttr.getValue();
1291 unsigned bitwidth = value.getBitWidth();
1292 if (bitwidth > 32) {
1293 emitError(loc, "Too wide attribute for OpGraphConstantARM: ")
1294 << bitwidth << " bits";
1295 return 0;
1296 }
1297 bool isSigned = value.isSignedIntN(bitwidth);
1298
1299 uint32_t word = 0;
1300 if (isSigned) {
1301 word = static_cast<int32_t>(value.getSExtValue());
1302 } else {
1303 word = static_cast<uint32_t>(value.getZExtValue());
1304 }
1305 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM,
1306 {typeID, resultID, word});
1307 graphConstIDMap[intAttr] = resultID;
1308 return resultID;
1309}
1310
1311uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1312 bool isSpec) {
1313 if (!isSpec) {
1314 // We can de-duplicate normal constants, but not specialization constants.
1315 if (auto id = getConstantID(floatAttr)) {
1316 return id;
1317 }
1318 }
1319
1320 // Process the type for this float literal
1321 uint32_t typeID = 0;
1322 if (failed(processType(loc, floatAttr.getType(), typeID))) {
1323 return 0;
1324 }
1325
1326 auto resultID = getNextID();
1327 APFloat value = floatAttr.getValue();
1328 const llvm::fltSemantics *semantics = &value.getSemantics();
1329
1330 auto opcode =
1331 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1332
1333 if (semantics == &APFloat::IEEEsingle()) {
1334 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1335 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1336 } else if (semantics == &APFloat::IEEEdouble()) {
1337 struct DoubleWord {
1338 uint32_t word1;
1339 uint32_t word2;
1340 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1341 encodeInstructionInto(typesGlobalValues, opcode,
1342 {typeID, resultID, words.word1, words.word2});
1343 } else if (llvm::is_contained({&APFloat::IEEEhalf(), &APFloat::BFloat(),
1344 &APFloat::Float8E4M3FN(),
1345 &APFloat::Float8E5M2()},
1346 semantics)) {
1347 uint32_t word =
1348 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1349 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1350 } else {
1351 std::string valueStr;
1352 llvm::raw_string_ostream rss(valueStr);
1353 value.print(rss);
1354
1355 emitError(loc, "cannot serialize ")
1356 << floatAttr.getType() << "-typed float literal: " << valueStr;
1357 return 0;
1358 }
1359
1360 if (!isSpec) {
1361 constIDMap[floatAttr] = resultID;
1362 }
1363 return resultID;
1364}
1365
1366// Returns type of attribute. In case of a TypedAttr this will simply return
1367// the type. But for an ArrayAttr which is untyped and can be multidimensional
1368// it creates the ArrayType recursively.
1370 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1371 return typedAttr.getType();
1372 }
1373
1374 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1375 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
1376 }
1377
1378 return nullptr;
1379}
1380
1381uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
1382 Type resultType,
1383 Attribute valueAttr) {
1384 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1385 if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
1386 return id;
1387 }
1388
1389 uint32_t typeID = 0;
1390 if (failed(processType(loc, resultType, typeID))) {
1391 return 0;
1392 }
1393
1394 Type valueType = getValueType(valueAttr);
1395 if (!valueAttr)
1396 return 0;
1397
1398 auto compositeType = dyn_cast<CompositeType>(resultType);
1399 if (!compositeType)
1400 return 0;
1401 Type elementType = compositeType.getElementType(0);
1402
1403 uint32_t constandID;
1404 if (elementType == valueType) {
1405 constandID = prepareConstant(loc, elementType, valueAttr);
1406 } else {
1407 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1408 }
1409
1410 uint32_t resultID = getNextID();
1411 if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
1412 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1413 {typeID, resultID});
1414 } else {
1415 encodeInstructionInto(typesGlobalValues,
1416 spirv::Opcode::OpConstantCompositeReplicateEXT,
1417 {typeID, resultID, constandID});
1418 }
1419
1420 constCompositeReplicateIDMap[valueTypePair] = resultID;
1421 return resultID;
1422}
1423
1424//===----------------------------------------------------------------------===//
1425// Control flow
1426//===----------------------------------------------------------------------===//
1427
1428uint32_t Serializer::getOrCreateBlockID(Block *block) {
1429 if (uint32_t id = getBlockID(block))
1430 return id;
1431 return blockIDMap[block] = getNextID();
1432}
1433
1434#ifndef NDEBUG
1435void Serializer::printBlock(Block *block, raw_ostream &os) {
1436 os << "block " << block << " (id = ";
1437 if (uint32_t id = getBlockID(block))
1438 os << id;
1439 else
1440 os << "unknown";
1441 os << ")\n";
1442}
1443#endif
1444
1445LogicalResult
1446Serializer::processBlock(Block *block, bool omitLabel,
1447 function_ref<LogicalResult()> emitMerge) {
1448 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1449 LLVM_DEBUG(block->print(llvm::dbgs()));
1450 LLVM_DEBUG(llvm::dbgs() << '\n');
1451 if (!omitLabel) {
1452 uint32_t blockID = getOrCreateBlockID(block);
1453 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1454
1455 // Emit OpLabel for this block.
1456 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1457 }
1458
1459 // Emit OpPhi instructions for block arguments, if any.
1460 if (failed(emitPhiForBlockArguments(block)))
1461 return failure();
1462
1463 // If we need to emit merge instructions, it must happen in this block. Check
1464 // whether we have other structured control flow ops, which will be expanded
1465 // into multiple basic blocks. If that's the case, we need to emit the merge
1466 // right now and then create new blocks for further serialization of the ops
1467 // in this block.
1468 if (emitMerge &&
1469 llvm::any_of(block->getOperations(),
1470 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1471 if (failed(emitMerge()))
1472 return failure();
1473 emitMerge = nullptr;
1474
1475 // Start a new block for further serialization.
1476 uint32_t blockID = getNextID();
1477 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1478 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1479 }
1480
1481 // Process each op in this block except the terminator.
1482 for (Operation &op : llvm::drop_end(*block)) {
1483 if (failed(processOperation(&op)))
1484 return failure();
1485 }
1486
1487 // Process the terminator.
1488 if (emitMerge)
1489 if (failed(emitMerge()))
1490 return failure();
1491 if (failed(processOperation(&block->back())))
1492 return failure();
1493
1494 return success();
1495}
1496
/// Emits one OpPhi per argument of `block` into `functionBody`.
///
/// Each OpPhi pairs, for every MLIR predecessor, the value it forwards for
/// that argument with the SPIR-V parent block (see getPhiIncomingBlock).
/// Values whose <id> is not yet known are recorded in `deferredPhiValues`
/// with the word offset of their slot so they can be patched later. The
/// phi's result <id> is recorded as the argument's <id> in `valueIDMap`.
/// Fails on an unsupported predecessor terminator or an unserializable type.
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
  // Nothing to do if this block has no arguments or it's the entry block, which
  // always has the same arguments as the function signature.
  if (block->args_empty() || block->isEntryBlock())
    return success();

  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");

  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
  // A SPIR-V OpPhi instruction is of the syntax:
  // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
  // So we need to collect all predecessor blocks and the arguments they send
  // to this block.
  SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
  for (Block *mlirPredecessor : block->getPredecessors()) {
    auto *terminator = mlirPredecessor->getTerminator();
    LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
    LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
    // The predecessor here is the immediate one according to MLIR's IR
    // structure. It does not directly map to the incoming parent block for the
    // OpPhi instructions at SPIR-V binary level. This is because structured
    // control flow ops are serialized to multiple SPIR-V blocks. If there is a
    // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
    // the branch op jumping to the OpPhi's block then resides in the previous
    // structured control flow op's merge block.
    Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
    LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
    LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
    // Pick the operand range this terminator forwards to `block`; only
    // spirv.Branch, spirv.BranchConditional and spirv.Switch are handled.
    if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
      predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
    } else if (auto branchCondOp =
                   dyn_cast<spirv::BranchConditionalOp>(terminator)) {
      std::optional<OperandRange> blockOperands;
      if (branchCondOp.getTrueTarget() == block) {
        blockOperands = branchCondOp.getTrueTargetOperands();
      } else {
        assert(branchCondOp.getFalseTarget() == block);
        blockOperands = branchCondOp.getFalseTargetOperands();
      }
      assert(!blockOperands->empty() &&
             "expected non-empty block operand range");
      predecessors.emplace_back(spirvPredecessor, *blockOperands);
    } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
      std::optional<OperandRange> blockOperands;
      if (block == switchOp.getDefaultTarget()) {
        blockOperands = switchOp.getDefaultOperands();
      } else {
        // Locate `block` among the case targets to fetch that case's operands.
        SuccessorRange targets = switchOp.getTargets();
        auto it = llvm::find(targets, block);
        assert(it != targets.end());
        size_t index = std::distance(targets.begin(), it);
        blockOperands = switchOp.getTargetOperands(index);
      }
      assert(!blockOperands->empty() &&
             "expected non-empty block operand range");
      predecessors.emplace_back(spirvPredecessor, *blockOperands);
    } else {
      return terminator->emitError("unimplemented terminator for Phi creation");
    }
    LLVM_DEBUG({
      llvm::dbgs() << " block arguments:\n";
      for (Value v : predecessors.back().second)
        llvm::dbgs() << " " << v << "\n";
    });
  }

  // Then create OpPhi instruction for each of the block argument.
  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
    BlockArgument arg = block->getArgument(argIndex);

    // Get the type <id> and result <id> for this OpPhi instruction.
    uint32_t phiTypeID = 0;
    if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
      return failure();
    uint32_t phiID = getNextID();

    LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                            << arg << " (id = " << phiID << ")\n");

    // Prepare the (value <id>, parent block <id>) pairs.
    SmallVector<uint32_t, 8> phiArgs;
    phiArgs.push_back(phiTypeID);
    phiArgs.push_back(phiID);

    for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
      Value value = predecessors[predIndex].second[argIndex];
      uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
      LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                              << ") value " << value << ' ');
      // Each pair is a value <id> ...
      uint32_t valueId = getValueID(value);
      if (valueId == 0) {
        // The op generating this value hasn't been visited yet so we don't have
        // an <id> assigned yet. Record this to fix up later.
        // The recorded offset is the absolute word position of this slot once
        // the OpPhi is appended: functionBody.size() is where the OpPhi opcode
        // word will land, +1 skips that word, and phiArgs.size() is this
        // slot's index among the operands.
        LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
        deferredPhiValues[value].push_back(functionBody.size() + 1 +
                                           phiArgs.size());
      } else {
        LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
      }
      // A placeholder 0 is written for deferred values.
      phiArgs.push_back(valueId);
      // ... and a parent block <id>.
      phiArgs.push_back(predBlockId);
    }

    encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
    valueIDMap[arg] = phiID;
  }

  return success();
}
1609
1610//===----------------------------------------------------------------------===//
1611// Operation
1612//===----------------------------------------------------------------------===//
1613
1614LogicalResult Serializer::encodeExtensionInstruction(
1615 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1616 ArrayRef<uint32_t> operands, SmallVectorImpl<uint32_t> &binary) {
1617 // Check if the extension has been imported.
1618 auto &setID = extendedInstSetIDMap[extensionSetName];
1619 if (!setID) {
1620 setID = getNextID();
1621 SmallVector<uint32_t, 16> importOperands;
1622 importOperands.push_back(setID);
1623 spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1624 encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1625 importOperands);
1626 }
1627
1628 // The first two operands are the result type <id> and result <id>. The set
1629 // <id> and the opcode need to be insert after this.
1630 if (operands.size() < 2) {
1631 return op->emitError("extended instructions must have a result encoding");
1632 }
1633 SmallVector<uint32_t, 8> extInstOperands;
1634 extInstOperands.reserve(operands.size() + 2);
1635 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1636 extInstOperands.push_back(setID);
1637 extInstOperands.push_back(extensionOpcode);
1638 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1639 encodeInstructionInto(binary, spirv::Opcode::OpExtInst, extInstOperands);
1640 return success();
1641}
1642
1643LogicalResult Serializer::encodeExtensionInstruction(
1644 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1645 ArrayRef<uint32_t> operands) {
1646 if (failed(encodeExtensionInstruction(op, extensionSetName, extensionOpcode,
1647 operands, functionBody)))
1648 return failure();
1649
1650 if (extensionSetName == extTosa)
1651 updateTosaOpsMap(op);
1652
1653 return success();
1654}
1655
1656LogicalResult Serializer::processOperation(Operation *opInst) {
1657 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1658
1659 // First dispatch the ops that do not directly mirror an instruction from
1660 // the SPIR-V spec.
1662 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1663 .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1664 .Case([&](spirv::BranchConditionalOp op) {
1665 return processBranchConditionalOp(op);
1666 })
1667 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1668 .Case([&](spirv::CompositeConstructOp op) {
1669 return processCompositeConstructOp(op);
1670 })
1671 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1672 return processConstantCompositeReplicateOp(op);
1673 })
1674 .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1675 .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); })
1676 .Case([&](spirv::GraphEntryPointARMOp op) {
1677 return processGraphEntryPointARMOp(op);
1678 })
1679 .Case([&](spirv::GraphOutputsARMOp op) {
1680 return processGraphOutputsARMOp(op);
1681 })
1682 .Case([&](spirv::GlobalVariableOp op) {
1683 return processGlobalVariableOp(op);
1684 })
1685 .Case([&](spirv::GraphConstantARMOp op) {
1686 return processGraphConstantARMOp(op);
1687 })
1688 .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1689 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1690 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1691 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1692 .Case([&](spirv::SpecConstantCompositeOp op) {
1693 return processSpecConstantCompositeOp(op);
1694 })
1695 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1696 return processSpecConstantCompositeReplicateOp(op);
1697 })
1698 .Case([&](spirv::SpecConstantOperationOp op) {
1699 return processSpecConstantOperationOp(op);
1700 })
1701 .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); })
1702 .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1703 .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1704
1705 // Then handle all the ops that directly mirror SPIR-V instructions with
1706 // auto-generated methods.
1707 .Default(
1708 [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1709}
1710
1711LogicalResult
1712Serializer::processCompositeConstructOp(spirv::CompositeConstructOp op) {
1713 Location loc = op.getLoc();
1714
1715 uint32_t resultTypeID = 0;
1716 if (failed(processType(loc, op.getType(), resultTypeID)))
1717 return failure();
1718
1719 uint32_t resultID = getNextID();
1720 valueIDMap[op.getResult()] = resultID;
1721
1722 SmallVector<uint32_t, 8> operands;
1723 operands.reserve(2 + op.getConstituents().size());
1724 operands.push_back(resultTypeID);
1725 operands.push_back(resultID);
1726 for (Value constituent : op.getConstituents()) {
1727 uint32_t id = getValueID(constituent);
1728 assert(id && "use before def!");
1729 operands.push_back(id);
1730 }
1731
1732 if (failed(emitDebugLine(functionBody, loc)))
1733 return failure();
1734
1735 encodeInstructionWithContinuationInto(
1736 functionBody, spirv::Opcode::OpCompositeConstruct, operands);
1737
1738 for (auto attr : op->getAttrs()) {
1739 if (failed(processDecoration(loc, resultID, attr)))
1740 return failure();
1741 }
1742
1743 return success();
1744}
1745
1746LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1747 StringRef extInstSet,
1748 uint32_t opcode) {
1749 SmallVector<uint32_t, 4> operands;
1750 Location loc = op->getLoc();
1751
1752 uint32_t resultID = 0;
1753 if (op->getNumResults() != 0) {
1754 uint32_t resultTypeID = 0;
1755 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1756 return failure();
1757 operands.push_back(resultTypeID);
1758
1759 resultID = getNextID();
1760 operands.push_back(resultID);
1761 valueIDMap[op->getResult(0)] = resultID;
1762 };
1763
1764 for (Value operand : op->getOperands())
1765 operands.push_back(getValueID(operand));
1766
1767 if (extInstSet != extTosa)
1768 // OpLine cannot be present in graphs
1769 if (failed(emitDebugLine(functionBody, loc)))
1770 return failure();
1771
1772 if (extInstSet.empty()) {
1773 encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1774 operands);
1775 } else {
1776 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1777 return failure();
1778 }
1779
1780 if (op->getNumResults() != 0) {
1781 for (auto attr : op->getAttrs()) {
1782 if (failed(processDecoration(loc, resultID, attr)))
1783 return failure();
1784 }
1785 }
1786
1787 return success();
1788}
1789
1790void Serializer::updateTosaOpsMap(Operation *op) {
1791 if (!options.emitDebugInfo)
1792 return;
1793
1794 if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op->getParentOp())) {
1795 if (uint32_t graphID = getFunctionID(graphOp.getName()))
1796 tosaOpsMap[graphID][op->getLoc()].insert(op);
1797 }
1798}
1799
1800LogicalResult Serializer::emitDecoration(uint32_t target,
1801 spirv::Decoration decoration,
1802 ArrayRef<uint32_t> params) {
1803 uint32_t wordCount = 3 + params.size();
1804 llvm::append_values(
1805 decorations,
1806 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1807 static_cast<uint32_t>(decoration));
1808 llvm::append_range(decorations, params);
1809 return success();
1810}
1811
1812LogicalResult Serializer::emitDecorationId(uint32_t target,
1813 spirv::Decoration decoration,
1814 ArrayRef<uint32_t> operandIds) {
1815 uint32_t wordCount = 3 + operandIds.size();
1816 llvm::append_values(
1817 decorations,
1818 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorateId), target,
1819 static_cast<uint32_t>(decoration));
1820 llvm::append_range(decorations, operandIds);
1821 return success();
1822}
1823
1824LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1825 Location loc) {
1826 if (!options.emitDebugInfo)
1827 return success();
1828
1829 if (lastProcessedWasMergeInst) {
1830 lastProcessedWasMergeInst = false;
1831 return success();
1832 }
1833
1834 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1835 if (fileLoc)
1836 encodeInstructionInto(binary, spirv::Opcode::OpLine,
1837 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1838 return success();
1839}
1840} // namespace spirv
1841} // namespace mlir
return success()
ArrayAttr()
b getContext())
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
static bool isZeroValue(Attribute attr)
static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp)
Move all functions declaration before functions definitions.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Location getLoc() const
Return the location for this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
OpListType & getOperations()
Definition Block.h:147
Operation & back()
Definition Block.h:162
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:109
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
llvm::iplist< Operation > OpListType
This is the list of operations in the block.
Definition Block.h:146
bool getValue() const
Return the boolean value of this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
Type getElementType(unsigned) const
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::optional< spirv::Opcode > getContinuationOpcode(spirv::Opcode parent)
Returns the SPV_INTEL_long_composites continuation opcode that may follow parent, or std::nullopt if ...
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
constexpr uint32_t kMaxWordCount
Max number of words https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_universal_limits.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
constexpr llvm::StringLiteral extTosa
Extension set name for TOSA ops.
static LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147