MLIR 23.0.0git
Serializer.cpp
Go to the documentation of this file.
1//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
27#include <cstdint>
28#include <optional>
29
30#define DEBUG_TYPE "spirv-serialization"
31
32using namespace mlir;
33
34/// Returns the merge block if the given `op` is a structured control flow op.
35/// Otherwise returns nullptr.
37 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
41 return nullptr;
42}
43
44/// Given a predecessor `block` for a block with arguments, returns the block
45/// that should be used as the parent block for SPIR-V OpPhi instructions
46/// corresponding to the block arguments.
48 // If the predecessor block in question is the entry block for a
49 // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50 if (block->isEntryBlock()) {
51 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52 // Then the incoming parent block for OpPhi should be the merge block of
53 // the structured control flow op before this loop.
54 Operation *op = loopOp.getOperation();
55 while ((op = op->getPrevNode()) != nullptr)
56 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57 return incomingBlock;
58 // Or the enclosing block itself if no structured control flow ops
59 // exists before this loop.
60 return loopOp->getBlock();
61 }
62 }
63
64 // Otherwise, we jump from the given predecessor block. Try to see if there is
65 // a structured control flow op inside it.
66 for (Operation &op : llvm::reverse(block->getOperations())) {
67 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68 return incomingBlock;
69 }
70 return block;
71}
72
73static bool isZeroValue(Attribute attr) {
74 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
76 }
77 if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
79 }
80 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
82 }
83 if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
84 return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
85 }
86 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
87 return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
88 }
89 return false;
90}
91
92/// Move all functions declaration before functions definitions. In SPIR-V
93/// "declarations" are functions without a body and "definitions" functions
94/// with a body. This is stronger than necessary. It should be sufficient to
95/// ensure any declarations precede their uses and not all definitions, however
96/// this allows to avoid analysing every function in the module this way.
97static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp) {
98 Block::OpListType &ops = moduleOp.getBody()->getOperations();
99 if (ops.empty())
100 return;
101 Operation &firstOp = ops.front();
102 for (Operation &op : llvm::drop_begin(ops))
103 if (auto funcOp = dyn_cast<spirv::FuncOp>(op))
104 if (funcOp.getBody().empty())
105 funcOp->moveBefore(&firstOp);
106}
107
108namespace mlir {
109namespace spirv {
110
111/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
112/// the given `binary` vector.
114 ArrayRef<uint32_t> operands) {
115 uint32_t wordCount = 1 + operands.size();
116 binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
117 binary.append(operands.begin(), operands.end());
118}
119
120Serializer::Serializer(spirv::ModuleOp module,
121 const SerializationOptions &options)
122 : module(module), mlirBuilder(module.getContext()), options(options) {}
123
124LogicalResult Serializer::serialize() {
125 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
126
127 if (failed(module.verifyInvariants()))
128 return failure();
129
130 // TODO: handle the other sections
131 processCapability();
132 if (failed(processExtension())) {
133 return failure();
134 }
135 processMemoryModel();
136 processDebugInfo();
137
139
140 // Iterate over the module body to serialize it. Assumptions are that there is
141 // only one basic block in the moduleOp
142 for (auto &op : *module.getBody()) {
143 if (failed(processOperation(&op))) {
144 return failure();
145 }
146 }
147
148 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
149 return success();
150}
151
153 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
154 extensions.size() + extendedSets.size() +
155 memoryModel.size() + entryPoints.size() +
156 executionModes.size() + decorations.size() +
157 typesGlobalValues.size() + functions.size() + graphs.size();
158
159 binary.clear();
160 binary.reserve(moduleSize);
161
162 spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
163 nextID);
164 binary.append(capabilities.begin(), capabilities.end());
165 binary.append(extensions.begin(), extensions.end());
166 binary.append(extendedSets.begin(), extendedSets.end());
167 binary.append(memoryModel.begin(), memoryModel.end());
168 binary.append(entryPoints.begin(), entryPoints.end());
169 binary.append(executionModes.begin(), executionModes.end());
170 binary.append(debug.begin(), debug.end());
171 binary.append(names.begin(), names.end());
172 binary.append(decorations.begin(), decorations.end());
173 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
174 binary.append(functions.begin(), functions.end());
175 binary.append(graphs.begin(), graphs.end());
176 binary.append(graphsDebugInfo.begin(), graphsDebugInfo.end());
177}
178
179#ifndef NDEBUG
181 os << "\n= Value <id> Map =\n\n";
182 for (auto valueIDPair : valueIDMap) {
183 Value val = valueIDPair.first;
184 os << " " << val << " "
185 << "id = " << valueIDPair.second << ' ';
186 if (auto *op = val.getDefiningOp()) {
187 os << "from op '" << op->getName() << "'";
188 } else if (auto arg = dyn_cast<BlockArgument>(val)) {
189 Block *block = arg.getOwner();
190 os << "from argument of block " << block << ' ';
191 os << " in op '" << block->getParentOp()->getName() << "'";
192 }
193 os << '\n';
194 }
195}
196#endif
197
198//===----------------------------------------------------------------------===//
199// Module structure
200//===----------------------------------------------------------------------===//
201
202uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
203 auto funcID = funcIDMap.lookup(fnName);
204 if (!funcID) {
205 funcID = getNextID();
206 funcIDMap[fnName] = funcID;
207 }
208 return funcID;
209}
210
211void Serializer::processCapability() {
212 for (auto cap : module.getVceTriple()->getCapabilities())
213 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
214 {static_cast<uint32_t>(cap)});
215}
216
217void Serializer::addLongCompositesCapability() {
218 if (longCompositesEmitted)
219 return;
220 longCompositesEmitted = true;
221 auto vceTriple = module.getVceTriple();
222 if (!llvm::is_contained(vceTriple->getCapabilities(),
223 spirv::Capability::LongCompositesINTEL))
225 capabilities, spirv::Opcode::OpCapability,
226 {static_cast<uint32_t>(spirv::Capability::LongCompositesINTEL)});
227 if (!llvm::is_contained(vceTriple->getExtensions(),
228 spirv::Extension::SPV_INTEL_long_composites)) {
229 SmallVector<uint32_t, 8> extName;
231 extName,
232 spirv::stringifyExtension(spirv::Extension::SPV_INTEL_long_composites));
233 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
234 }
235}
236
/// Encodes `op` with `operands` into `binary`, splitting it when it is too
/// long for a single instruction.
/// - If 1 + operands.size() <= spirv::kMaxWordCount, the instruction is
///   emitted unchanged.
/// - Otherwise, the first spirv::kMaxWordCount - 1 operand words are emitted
///   under `op`.
/// - The remaining words are emitted in chunks of at most
///   spirv::kMaxWordCount - 1 words under `*continuationOp`.
/// - In the split case, addLongCompositesCapability() is then called.
void Serializer::encodeInstructionWithContinuationInto(
    SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
    ArrayRef<uint32_t> operands) {
  // Fast path: the whole instruction (opcode word + operands) fits within
  // the word-count limit.
  if (1 + operands.size() <= spirv::kMaxWordCount) {
    encodeInstructionInto(binary, op, operands);
    return;
  }

  // NOTE(review): the initializer of `continuationOp` (Doxygen line 246) is
  // missing from this copy of the file. Presumably it maps `op` to its
  // continuation opcode and yields std::nullopt for opcodes that cannot be
  // split. Restore it from the upstream source.
  std::optional<spirv::Opcode> continuationOp =
  assert(continuationOp && "op is not a splittable composite/struct opcode");

  // Each instruction can carry at most kMaxWordCount - 1 operand words; the
  // remaining word is the one encodeInstructionInto prepends.
  const unsigned chunk = spirv::kMaxWordCount - 1;
  encodeInstructionInto(binary, op, operands.take_front(chunk));
  // Emit the leftover operands as continuation instructions, chunk by chunk.
  for (ArrayRef<uint32_t> rest = operands.drop_front(chunk); !rest.empty();
       rest = rest.drop_front(std::min<size_t>(rest.size(), chunk))) {
    encodeInstructionInto(binary, *continuationOp, rest.take_front(chunk));
  }

  // Splitting requires the long-composites capability/extension.
  addLongCompositesCapability();
}
258
259void Serializer::processDebugInfo() {
260 if (!options.emitDebugInfo)
261 return;
262 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
263 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
264 fileID = getNextID();
265 SmallVector<uint32_t, 16> operands;
266 operands.push_back(fileID);
267 spirv::encodeStringLiteralInto(operands, fileName);
268 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
269 // TODO: Encode more debug instructions.
270}
271
272LogicalResult Serializer::processExtension() {
273 llvm::SmallVector<uint32_t, 16> extName;
274 llvm::SmallSet<Extension, 4> deducedExts(
275 llvm::from_range, module.getVceTriple()->getExtensions());
276 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
277 if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
278 TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module);
279 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
280 return module.emitError(
281 "SPV_KHR_non_semantic_info extension not available");
282 deducedExts.insert(nonSemanticInfoExt);
283 }
284 for (spirv::Extension ext : deducedExts) {
285 extName.clear();
286 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
287 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
288 }
289 return success();
290}
291
292void Serializer::processMemoryModel() {
293 StringAttr memoryModelName = module.getMemoryModelAttrName();
294 auto mm = static_cast<uint32_t>(
295 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
296 .getValue());
297
298 StringAttr addressingModelName = module.getAddressingModelAttrName();
299 auto am = static_cast<uint32_t>(
300 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
301 .getValue());
302
303 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
304}
305
306static std::string getDecorationName(StringRef attrName) {
307 // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
308 // expected FPFastMathMode.
309 if (attrName == "fp_fast_math_mode")
310 return "FPFastMathMode";
311 // similar here
312 if (attrName == "fp_rounding_mode")
313 return "FPRoundingMode";
314 // convertToCamelFromSnakeCase will not capitalize "INTEL".
315 if (attrName == "cache_control_load_intel")
316 return "CacheControlLoadINTEL";
317 if (attrName == "cache_control_store_intel")
318 return "CacheControlStoreINTEL";
319
320 return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
321}
322
323template <typename AttrTy, typename EmitF>
324static LogicalResult processDecorationList(Location loc, Decoration decoration,
325 Attribute attrList,
326 StringRef attrName, EmitF emitter) {
327 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
328 if (!arrayAttr) {
329 return emitError(loc, "expecting array attribute of ")
330 << attrName << " for " << stringifyDecoration(decoration);
331 }
332 if (arrayAttr.empty()) {
333 return emitError(loc, "expecting non-empty array attribute of ")
334 << attrName << " for " << stringifyDecoration(decoration);
335 }
336 for (Attribute attr : arrayAttr.getValue()) {
337 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
338 if (!cacheControlAttr) {
339 return emitError(loc, "expecting array attribute of ")
340 << attrName << " for " << stringifyDecoration(decoration);
341 }
342 // This named attribute encodes several decorations. Emit one per
343 // element in the array.
344 if (failed(emitter(cacheControlAttr)))
345 return failure();
346 }
347 return success();
348}
349
350LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
351 Decoration decoration,
352 Attribute attr) {
354 switch (decoration) {
355 case spirv::Decoration::LinkageAttributes: {
356 // Get the value of the Linkage Attributes
357 // e.g., LinkageAttributes=["linkageName", linkageType].
358 auto linkageAttr = dyn_cast<spirv::LinkageAttributesAttr>(attr);
359 auto linkageName = linkageAttr.getLinkageName();
360 auto linkageType = linkageAttr.getLinkageType().getValue();
361 // Encode the Linkage Name (string literal to uint32_t).
362 spirv::encodeStringLiteralInto(args, linkageName);
363 // Encode LinkageType & Add the Linkagetype to the args.
364 args.push_back(static_cast<uint32_t>(linkageType));
365 break;
366 }
367 case spirv::Decoration::FPFastMathMode:
368 if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
369 args.push_back(static_cast<uint32_t>(intAttr.getValue()));
370 break;
371 }
372 return emitError(loc, "expected FPFastMathModeAttr attribute for ")
373 << stringifyDecoration(decoration);
374 case spirv::Decoration::FPRoundingMode:
375 if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
376 args.push_back(static_cast<uint32_t>(intAttr.getValue()));
377 break;
378 }
379 return emitError(loc, "expected FPRoundingModeAttr attribute for ")
380 << stringifyDecoration(decoration);
381 case spirv::Decoration::Binding:
382 case spirv::Decoration::DescriptorSet:
383 case spirv::Decoration::Location:
384 case spirv::Decoration::Index:
385 case spirv::Decoration::Offset:
386 case spirv::Decoration::XfbBuffer:
387 case spirv::Decoration::XfbStride:
388 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
389 args.push_back(intAttr.getValue().getZExtValue());
390 break;
391 }
392 return emitError(loc, "expected integer attribute for ")
393 << stringifyDecoration(decoration);
394 case spirv::Decoration::BuiltIn:
395 if (auto strAttr = dyn_cast<StringAttr>(attr)) {
396 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
397 if (enumVal) {
398 args.push_back(static_cast<uint32_t>(*enumVal));
399 break;
400 }
401 return emitError(loc, "invalid ")
402 << stringifyDecoration(decoration) << " decoration attribute "
403 << strAttr.getValue();
404 }
405 return emitError(loc, "expected string attribute for ")
406 << stringifyDecoration(decoration);
407 case spirv::Decoration::Aliased:
408 case spirv::Decoration::AliasedPointer:
409 case spirv::Decoration::Flat:
410 case spirv::Decoration::NonReadable:
411 case spirv::Decoration::NonWritable:
412 case spirv::Decoration::NoPerspective:
413 case spirv::Decoration::NoSignedWrap:
414 case spirv::Decoration::NoUnsignedWrap:
415 case spirv::Decoration::RelaxedPrecision:
416 case spirv::Decoration::Restrict:
417 case spirv::Decoration::RestrictPointer:
418 case spirv::Decoration::NoContraction:
419 case spirv::Decoration::Constant:
420 case spirv::Decoration::Block:
421 case spirv::Decoration::BufferBlock:
422 case spirv::Decoration::Invariant:
423 case spirv::Decoration::Patch:
424 case spirv::Decoration::Coherent:
425 // For unit attributes and decoration attributes, the args list
426 // has no values so we do nothing.
427 if (isa<UnitAttr, DecorationAttr>(attr))
428 break;
429 return emitError(loc,
430 "expected unit attribute or decoration attribute for ")
431 << stringifyDecoration(decoration);
432 case spirv::Decoration::CacheControlLoadINTEL:
434 loc, decoration, attr, "CacheControlLoadINTEL",
435 [&](CacheControlLoadINTELAttr attr) {
436 unsigned cacheLevel = attr.getCacheLevel();
437 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
438 return emitDecoration(
439 resultID, decoration,
440 {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
441 });
442 case spirv::Decoration::CacheControlStoreINTEL:
444 loc, decoration, attr, "CacheControlStoreINTEL",
445 [&](CacheControlStoreINTELAttr attr) {
446 unsigned cacheLevel = attr.getCacheLevel();
447 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
448 return emitDecoration(
449 resultID, decoration,
450 {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
451 });
452 case spirv::Decoration::AlignmentId:
453 case spirv::Decoration::MaxByteOffsetId:
454 case spirv::Decoration::CounterBuffer: {
455 auto symRef = dyn_cast<FlatSymbolRefAttr>(attr);
456 if (!symRef)
457 return emitError(loc, "expected symbol reference for ")
458 << stringifyDecoration(decoration);
459 StringRef symName = symRef.getValue();
460 uint32_t operandID = getVariableID(symName);
461 if (!operandID)
462 operandID = getSpecConstID(symName);
463 if (!operandID)
464 return emitError(loc, "could not find <id> for symbol '")
465 << symName << "' referenced by "
466 << stringifyDecoration(decoration);
467 return emitDecorationId(resultID, decoration, {operandID});
468 }
469 default:
470 return emitError(loc, "unhandled decoration ")
471 << stringifyDecoration(decoration);
472 }
473 return emitDecoration(resultID, decoration, args);
474}
475
476LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
477 NamedAttribute attr) {
478 StringRef attrName = attr.getName().strref();
479 std::string decorationName = getDecorationName(attrName);
480 std::optional<Decoration> decoration =
481 spirv::symbolizeDecoration(decorationName);
482 if (!decoration) {
483 return emitError(
484 loc, "non-argument attributes expected to have snake-case-ified "
485 "decoration name, unhandled attribute with name : ")
486 << attrName;
487 }
488 return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
489}
490
491LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
492 assert(!name.empty() && "unexpected empty string for OpName");
493 if (!options.emitSymbolName)
494 return success();
495
496 SmallVector<uint32_t, 4> nameOperands;
497 nameOperands.push_back(resultID);
498 spirv::encodeStringLiteralInto(nameOperands, name);
499 encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
500 return success();
501}
502
503template <>
504LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
505 Location loc, spirv::ArrayType type, uint32_t resultID) {
506 if (unsigned stride = type.getArrayStride()) {
507 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
508 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
509 }
510 return success();
511}
512
513template <>
514LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
515 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
516 if (unsigned stride = type.getArrayStride()) {
517 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
518 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
519 }
520 return success();
521}
522
523LogicalResult Serializer::processMemberDecoration(
524 uint32_t structID,
525 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
527 {structID, memberDecoration.memberIndex,
528 static_cast<uint32_t>(memberDecoration.decoration)});
529 if (memberDecoration.hasValue()) {
530 args.push_back(
531 cast<IntegerAttr>(memberDecoration.decorationValue).getInt());
532 }
533 encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
534 return success();
535}
536
537//===----------------------------------------------------------------------===//
538// Type
539//===----------------------------------------------------------------------===//
540
541// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
542// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
543// PushConstant Storage Classes must be explicitly laid out."
544bool Serializer::isInterfaceStructPtrType(Type type) const {
545 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
546 switch (ptrType.getStorageClass()) {
547 case spirv::StorageClass::PhysicalStorageBuffer:
548 case spirv::StorageClass::PushConstant:
549 case spirv::StorageClass::StorageBuffer:
550 case spirv::StorageClass::Uniform:
551 return isa<spirv::StructType>(ptrType.getPointeeType());
552 default:
553 break;
554 }
555 }
556 return false;
557}
558
559LogicalResult Serializer::processType(Location loc, Type type,
560 uint32_t &typeID) {
561 // Maintains a set of names for nested identified struct types. This is used
562 // to properly serialize recursive references.
563 SetVector<StringRef> serializationCtx;
564 return processTypeImpl(loc, type, typeID, serializationCtx);
565}
566
567LogicalResult
568Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
569 SetVector<StringRef> &serializationCtx) {
570
571 // Map unsigned integer types to singless integer types.
572 // This is needed otherwise the generated spirv assembly will contain
573 // twice a type declaration (like OpTypeInt 32 0) which is no permitted and
574 // such module fails validation. Indeed at MLIR level the two types are
575 // different and lookup in the cache below misses.
576 // Note: This conversion needs to happen here before the type is looked up in
577 // the cache.
578 if (type.isUnsignedInteger()) {
579 type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
580 IntegerType::SignednessSemantics::Signless);
581 }
582
583 typeID = getTypeID(type);
584 if (typeID)
585 return success();
586
587 typeID = getNextID();
588 SmallVector<uint32_t, 4> operands;
589
590 operands.push_back(typeID);
591 auto typeEnum = spirv::Opcode::OpTypeVoid;
592 bool deferSerialization = false;
593
594 if ((isa<FunctionType>(type) &&
595 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
596 operands))) ||
597 (isa<GraphType>(type) &&
598 succeeded(
599 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
600 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
601 deferSerialization, serializationCtx))) {
602 if (deferSerialization)
603 return success();
604
605 typeIDMap[type] = typeID;
606
607 if (typeEnum == spirv::Opcode::OpTypeStruct)
608 encodeInstructionWithContinuationInto(typesGlobalValues, typeEnum,
609 operands);
610 else
611 encodeInstructionInto(typesGlobalValues, typeEnum, operands);
612
613 if (recursiveStructInfos.count(type) != 0) {
614 // This recursive struct type is emitted already, now the OpTypePointer
615 // instructions referring to recursive references are emitted as well.
616 for (auto &ptrInfo : recursiveStructInfos[type]) {
617 // TODO: This might not work if more than 1 recursive reference is
618 // present in the struct.
619 SmallVector<uint32_t, 4> ptrOperands;
620 ptrOperands.push_back(ptrInfo.pointerTypeID);
621 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
622 ptrOperands.push_back(typeIDMap[type]);
623
624 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
625 ptrOperands);
626 }
627
628 recursiveStructInfos[type].clear();
629 }
630
631 return success();
632 }
633
634 return emitError(loc, "failed to process type: ") << type;
635}
636
637LogicalResult Serializer::prepareBasicType(
638 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
639 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
640 SetVector<StringRef> &serializationCtx) {
641 deferSerialization = false;
642
643 if (isVoidType(type)) {
644 typeEnum = spirv::Opcode::OpTypeVoid;
645 return success();
646 }
647
648 if (auto intType = dyn_cast<IntegerType>(type)) {
649 if (intType.getWidth() == 1) {
650 typeEnum = spirv::Opcode::OpTypeBool;
651 return success();
652 }
653
654 typeEnum = spirv::Opcode::OpTypeInt;
655 operands.push_back(intType.getWidth());
656 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
657 // to preserve or validate.
658 // 0 indicates unsigned, or no signedness semantics
659 // 1 indicates signed semantics."
660 operands.push_back(intType.isSigned() ? 1 : 0);
661 return success();
662 }
663
664 if (auto floatType = dyn_cast<FloatType>(type)) {
665 typeEnum = spirv::Opcode::OpTypeFloat;
666 operands.push_back(floatType.getWidth());
667 if (floatType.isBF16()) {
668 operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
669 }
670 if (floatType.isF8E4M3FN()) {
671 operands.push_back(
672 static_cast<uint32_t>(spirv::FPEncoding::Float8E4M3EXT));
673 }
674 if (floatType.isF8E5M2()) {
675 operands.push_back(
676 static_cast<uint32_t>(spirv::FPEncoding::Float8E5M2EXT));
677 }
678
679 return success();
680 }
681
682 if (auto vectorType = dyn_cast<VectorType>(type)) {
683 uint32_t elementTypeID = 0;
684 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
685 serializationCtx))) {
686 return failure();
687 }
688 typeEnum = spirv::Opcode::OpTypeVector;
689 operands.push_back(elementTypeID);
690 operands.push_back(vectorType.getNumElements());
691 return success();
692 }
693
694 if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
695 typeEnum = spirv::Opcode::OpTypeImage;
696 uint32_t sampledTypeID = 0;
697 if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
698 return failure();
699
700 llvm::append_values(operands, sampledTypeID,
701 static_cast<uint32_t>(imageType.getDim()),
702 static_cast<uint32_t>(imageType.getDepthInfo()),
703 static_cast<uint32_t>(imageType.getArrayedInfo()),
704 static_cast<uint32_t>(imageType.getSamplingInfo()),
705 static_cast<uint32_t>(imageType.getSamplerUseInfo()),
706 static_cast<uint32_t>(imageType.getImageFormat()));
707 return success();
708 }
709
710 if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
711 typeEnum = spirv::Opcode::OpTypeArray;
712 uint32_t elementTypeID = 0;
713 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
714 serializationCtx))) {
715 return failure();
716 }
717 operands.push_back(elementTypeID);
718 if (auto elementCountID = prepareConstantInt(
719 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
720 operands.push_back(elementCountID);
721 }
722 return processTypeDecoration(loc, arrayType, resultID);
723 }
724
725 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
726 uint32_t pointeeTypeID = 0;
727 spirv::StructType pointeeStruct =
728 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
729
730 if (pointeeStruct && pointeeStruct.isIdentified() &&
731 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
732 // A recursive reference to an enclosing struct is found.
733 //
734 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
735 // class as operands.
736 SmallVector<uint32_t, 2> forwardPtrOperands;
737 forwardPtrOperands.push_back(resultID);
738 forwardPtrOperands.push_back(
739 static_cast<uint32_t>(ptrType.getStorageClass()));
740
741 encodeInstructionInto(typesGlobalValues,
742 spirv::Opcode::OpTypeForwardPointer,
743 forwardPtrOperands);
744
745 // 2. Find the pointee (enclosing) struct.
746 auto structType = spirv::StructType::getIdentified(
747 module.getContext(), pointeeStruct.getIdentifier());
748
749 if (!structType)
750 return failure();
751
752 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
753 // as deferred.
754 deferSerialization = true;
755
756 // 4. Record the info needed to emit the deferred OpTypePointer
757 // instruction when the enclosing struct is completely serialized.
758 recursiveStructInfos[structType].push_back(
759 {resultID, ptrType.getStorageClass()});
760 } else {
761 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
762 serializationCtx)))
763 return failure();
764 }
765
766 typeEnum = spirv::Opcode::OpTypePointer;
767 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
768 operands.push_back(pointeeTypeID);
769
770 // TODO: Now struct decorations are supported this code may not be
771 // necessary. However, it is left to support backwards compatibility.
772 // Ideally, Block decorations should be inserted when converting to SPIR-V.
773 if (isInterfaceStructPtrType(ptrType)) {
774 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
775 if (!structType.hasDecoration(spirv::Decoration::Block) &&
776 !structType.hasDecoration(spirv::Decoration::BufferBlock))
777 if (failed(emitDecoration(getTypeID(pointeeStruct),
778 spirv::Decoration::Block)))
779 return emitError(loc, "cannot decorate ")
780 << pointeeStruct << " with Block decoration";
781 }
782
783 return success();
784 }
785
786 if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
787 uint32_t elementTypeID = 0;
788 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
789 elementTypeID, serializationCtx))) {
790 return failure();
791 }
792 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
793 operands.push_back(elementTypeID);
794 return processTypeDecoration(loc, runtimeArrayType, resultID);
795 }
796
797 if (isa<spirv::SamplerType>(type)) {
798 typeEnum = spirv::Opcode::OpTypeSampler;
799 return success();
800 }
801
802 if (isa<spirv::NamedBarrierType>(type)) {
803 typeEnum = spirv::Opcode::OpTypeNamedBarrier;
804 return success();
805 }
806
807 if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
808 typeEnum = spirv::Opcode::OpTypeSampledImage;
809 uint32_t imageTypeID = 0;
810 if (failed(
811 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
812 return failure();
813 }
814 operands.push_back(imageTypeID);
815 return success();
816 }
817
818 if (auto structType = dyn_cast<spirv::StructType>(type)) {
819 if (structType.isIdentified()) {
820 if (failed(processName(resultID, structType.getIdentifier())))
821 return failure();
822 serializationCtx.insert(structType.getIdentifier());
823 }
824
825 bool hasOffset = structType.hasOffset();
826 for (auto elementIndex :
827 llvm::seq<uint32_t>(0, structType.getNumElements())) {
828 uint32_t elementTypeID = 0;
829 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
830 elementTypeID, serializationCtx))) {
831 return failure();
832 }
833 operands.push_back(elementTypeID);
834 if (hasOffset) {
835 auto intType = IntegerType::get(structType.getContext(), 32);
836 // Decorate each struct member with an offset
837 spirv::StructType::MemberDecorationInfo offsetDecoration{
838 elementIndex, spirv::Decoration::Offset,
839 IntegerAttr::get(intType,
840 structType.getMemberOffset(elementIndex))};
841 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
842 return emitError(loc, "cannot decorate ")
843 << elementIndex << "-th member of " << structType
844 << " with its offset";
845 }
846 }
847 }
848 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
849 structType.getMemberDecorations(memberDecorations);
850
851 for (auto &memberDecoration : memberDecorations) {
852 if (failed(processMemberDecoration(resultID, memberDecoration))) {
853 return emitError(loc, "cannot decorate ")
854 << static_cast<uint32_t>(memberDecoration.memberIndex)
855 << "-th member of " << structType << " with "
856 << stringifyDecoration(memberDecoration.decoration);
857 }
858 }
859
860 SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
861 structType.getStructDecorations(structDecorations);
862
863 for (spirv::StructType::StructDecorationInfo &structDecoration :
864 structDecorations) {
865 if (failed(processDecorationAttr(loc, resultID,
866 structDecoration.decoration,
867 structDecoration.decorationValue))) {
868 return emitError(loc, "cannot decorate struct ")
869 << structType << " with "
870 << stringifyDecoration(structDecoration.decoration);
871 }
872 }
873
874 typeEnum = spirv::Opcode::OpTypeStruct;
875
876 if (structType.isIdentified())
877 serializationCtx.remove(structType.getIdentifier());
878
879 return success();
880 }
881
882 if (auto cooperativeMatrixType =
883 dyn_cast<spirv::CooperativeMatrixType>(type)) {
884 uint32_t elementTypeID = 0;
885 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
886 elementTypeID, serializationCtx))) {
887 return failure();
888 }
889 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
890 auto getConstantOp = [&](uint32_t id) {
891 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
892 return prepareConstantInt(loc, attr);
893 };
894 llvm::append_values(
895 operands, elementTypeID,
896 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
897 getConstantOp(cooperativeMatrixType.getRows()),
898 getConstantOp(cooperativeMatrixType.getColumns()),
899 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
900 return success();
901 }
902
903 if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
904 uint32_t elementTypeID = 0;
905 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
906 serializationCtx))) {
907 return failure();
908 }
909 typeEnum = spirv::Opcode::OpTypeMatrix;
910 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
911 return success();
912 }
913
914 if (auto tensorArmType = dyn_cast<TensorArmType>(type)) {
915 uint32_t elementTypeID = 0;
916 uint32_t rank = 0;
917 uint32_t shapeID = 0;
918 uint32_t rankID = 0;
919 if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
920 elementTypeID, serializationCtx))) {
921 return failure();
922 }
923 if (tensorArmType.hasRank()) {
924 ArrayRef<int64_t> dims = tensorArmType.getShape();
925 rank = dims.size();
926 rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
927 if (rankID == 0) {
928 return failure();
929 }
930
931 bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
932 if (rank > 0 && shaped) {
933 auto I32Type = IntegerType::get(type.getContext(), 32);
934 auto shapeType = ArrayType::get(I32Type, rank);
935 if (rank == 1) {
936 SmallVector<uint64_t, 1> index(rank);
937 shapeID = prepareDenseElementsConstant(
938 loc, shapeType,
939 mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
940 index);
941 } else {
942 shapeID = prepareArrayConstant(
943 loc, shapeType,
944 mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
945 }
946 if (shapeID == 0) {
947 return failure();
948 }
949 }
950 }
951 typeEnum = spirv::Opcode::OpTypeTensorARM;
952 operands.push_back(elementTypeID);
953 if (rankID == 0)
954 return success();
955 operands.push_back(rankID);
956 if (shapeID == 0)
957 return success();
958 operands.push_back(shapeID);
959 return success();
960 }
961
962 // TODO: Handle other types.
963 return emitError(loc, "unhandled type in serialization: ") << type;
964}
965
966LogicalResult
967Serializer::prepareFunctionType(Location loc, FunctionType type,
968 spirv::Opcode &typeEnum,
969 SmallVectorImpl<uint32_t> &operands) {
970 typeEnum = spirv::Opcode::OpTypeFunction;
971 assert(type.getNumResults() <= 1 &&
972 "serialization supports only a single return value");
973 uint32_t resultID = 0;
974 if (failed(processType(
975 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
976 resultID))) {
977 return failure();
978 }
979 operands.push_back(resultID);
980 for (auto &res : type.getInputs()) {
981 uint32_t argTypeID = 0;
982 if (failed(processType(loc, res, argTypeID))) {
983 return failure();
984 }
985 operands.push_back(argTypeID);
986 }
987 return success();
988}
989
990LogicalResult
991Serializer::prepareGraphType(Location loc, GraphType type,
992 spirv::Opcode &typeEnum,
993 SmallVectorImpl<uint32_t> &operands) {
994 typeEnum = spirv::Opcode::OpTypeGraphARM;
995 assert(type.getNumResults() >= 1 &&
996 "serialization requires at least a return value");
997
998 operands.push_back(type.getNumInputs());
999
1000 for (Type argType : type.getInputs()) {
1001 uint32_t argTypeID = 0;
1002 if (failed(processType(loc, argType, argTypeID)))
1003 return failure();
1004 operands.push_back(argTypeID);
1005 }
1006
1007 for (Type resType : type.getResults()) {
1008 uint32_t resTypeID = 0;
1009 if (failed(processType(loc, resType, resTypeID)))
1010 return failure();
1011 operands.push_back(resTypeID);
1012 }
1013
1014 return success();
1015}
1016
1017//===----------------------------------------------------------------------===//
1018// Constant
1019//===----------------------------------------------------------------------===//
1020
1021uint32_t Serializer::prepareConstant(Location loc, Type constType,
1022 Attribute valueAttr) {
1023 if (auto id = prepareConstantScalar(loc, valueAttr)) {
1024 return id;
1025 }
1026
1027 // This is a composite literal. We need to handle each component separately
1028 // and then emit an OpConstantComposite for the whole.
1029
1030 if (auto id = getConstantID(valueAttr)) {
1031 return id;
1032 }
1033
1034 uint32_t typeID = 0;
1035 if (failed(processType(loc, constType, typeID))) {
1036 return 0;
1037 }
1038
1039 uint32_t resultID = 0;
1040 if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
1041 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
1042 SmallVector<uint64_t, 4> index(rank);
1043 resultID = prepareDenseElementsConstant(loc, constType, attr,
1044 /*dim=*/0, index);
1045 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1046 resultID = prepareArrayConstant(loc, constType, arrayAttr);
1047 }
1048
1049 if (resultID == 0) {
1050 emitError(loc, "cannot serialize attribute: ") << valueAttr;
1051 return 0;
1052 }
1053
1054 constIDMap[valueAttr] = resultID;
1055 return resultID;
1056}
1057
1058uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1059 ArrayAttr attr) {
1060 uint32_t typeID = 0;
1061 if (failed(processType(loc, constType, typeID))) {
1062 return 0;
1063 }
1064
1065 uint32_t resultID = getNextID();
1066 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1067 operands.reserve(attr.size() + 2);
1068 spirv::CompositeType compositeType = cast<spirv::CompositeType>(constType);
1069 for (auto [idx, elementAttr] : llvm::enumerate(attr)) {
1070 if (uint32_t elementID = prepareConstant(
1071 loc, compositeType.getElementType(idx), elementAttr)) {
1072 operands.push_back(elementID);
1073 } else {
1074 return 0;
1075 }
1076 }
1077 encodeInstructionWithContinuationInto(
1078 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1079
1080 return resultID;
1081}
1082
1083// TODO: Turn the below function into iterative function, instead of
1084// recursive function.
1085uint32_t
1086Serializer::prepareDenseElementsConstant(Location loc, Type constType,
1087 DenseElementsAttr valueAttr, int dim,
1088 MutableArrayRef<uint64_t> index) {
1089 auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
1090 assert(dim <= shapedType.getRank());
1091 if (shapedType.getRank() == dim) {
1092 if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
1093 return attr.getType().getElementType().isInteger(1)
1094 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
1095 : prepareConstantInt(loc,
1096 attr.getValues<IntegerAttr>()[index]);
1097 }
1098 if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
1099 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
1100 }
1101 return 0;
1102 }
1103
1104 uint32_t typeID = 0;
1105 if (failed(processType(loc, constType, typeID))) {
1106 return 0;
1107 }
1108
1109 int64_t numberOfConstituents = shapedType.getDimSize(dim);
1110 uint32_t resultID = getNextID();
1111 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1112 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
1113 if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
1114 ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
1115 if (!innerShape.empty())
1116 elementType = spirv::TensorArmType::get(innerShape, elementType);
1117 }
1118
1119 // "If the Result Type is a cooperative matrix type, then there must be only
1120 // one Constituent, with scalar type matching the cooperative matrix Component
1121 // Type, and all components of the matrix are initialized to that value."
1122 // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
1123 if (isa<spirv::CooperativeMatrixType>(constType)) {
1124 if (!valueAttr.isSplat()) {
1125 emitError(
1126 loc,
1127 "cannot serialize a non-splat value for a cooperative matrix type");
1128 return 0;
1129 }
1130 // numberOfConstituents is 1, so we only need one more elements in the
1131 // SmallVector, so the total is 3 (1 + 2).
1132 operands.reserve(3);
1133 // We set dim directly to `shapedType.getRank()` so the recursive call
1134 // directly returns the scalar type.
1135 if (auto elementID = prepareDenseElementsConstant(
1136 loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
1137 operands.push_back(elementID);
1138 } else {
1139 return 0;
1140 }
1141 } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
1142 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1143 {typeID, resultID});
1144 return resultID;
1145 } else {
1146 operands.reserve(numberOfConstituents + 2);
1147 for (int i = 0; i < numberOfConstituents; ++i) {
1148 index[dim] = i;
1149 if (auto elementID = prepareDenseElementsConstant(
1150 loc, elementType, valueAttr, dim + 1, index)) {
1151 operands.push_back(elementID);
1152 } else {
1153 return 0;
1154 }
1155 }
1156 }
1157 encodeInstructionWithContinuationInto(
1158 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1159
1160 return resultID;
1161}
1162
1163uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1164 bool isSpec) {
1165 if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1166 return prepareConstantFp(loc, floatAttr, isSpec);
1167 }
1168 if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1169 return prepareConstantBool(loc, boolAttr, isSpec);
1170 }
1171 if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1172 return prepareConstantInt(loc, intAttr, isSpec);
1173 }
1174
1175 return 0;
1176}
1177
1178uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1179 bool isSpec) {
1180 if (!isSpec) {
1181 // We can de-duplicate normal constants, but not specialization constants.
1182 if (auto id = getConstantID(boolAttr)) {
1183 return id;
1184 }
1185 }
1186
1187 // Process the type for this bool literal
1188 uint32_t typeID = 0;
1189 if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
1190 return 0;
1191 }
1192
1193 auto resultID = getNextID();
1194 auto opcode = boolAttr.getValue()
1195 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1196 : spirv::Opcode::OpConstantTrue)
1197 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1198 : spirv::Opcode::OpConstantFalse);
1199 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1200
1201 if (!isSpec) {
1202 constIDMap[boolAttr] = resultID;
1203 }
1204 return resultID;
1205}
1206
1207uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1208 bool isSpec) {
1209 if (!isSpec) {
1210 // We can de-duplicate normal constants, but not specialization constants.
1211 if (auto id = getConstantID(intAttr)) {
1212 return id;
1213 }
1214 }
1215
1216 // Process the type for this integer literal
1217 uint32_t typeID = 0;
1218 if (failed(processType(loc, intAttr.getType(), typeID))) {
1219 return 0;
1220 }
1221
1222 auto resultID = getNextID();
1223 APInt value = intAttr.getValue();
1224 unsigned bitwidth = value.getBitWidth();
1225 bool isSigned = intAttr.getType().isSignedInteger();
1226 auto opcode =
1227 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1228
1229 switch (bitwidth) {
1230 // According to SPIR-V spec, "When the type's bit width is less than
1231 // 32-bits, the literal's value appears in the low-order bits of the word,
1232 // and the high-order bits must be 0 for a floating-point type, or 0 for an
1233 // integer type with Signedness of 0, or sign extended when Signedness
1234 // is 1."
1235 case 32:
1236 case 16:
1237 case 8: {
1238 uint32_t word = 0;
1239 if (isSigned) {
1240 word = static_cast<int32_t>(value.getSExtValue());
1241 } else {
1242 word = static_cast<uint32_t>(value.getZExtValue());
1243 }
1244 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1245 } break;
1246 // According to SPIR-V spec: "When the type's bit width is larger than one
1247 // word, the literal’s low-order words appear first."
1248 case 64: {
1249 struct DoubleWord {
1250 uint32_t word1;
1251 uint32_t word2;
1252 } words;
1253 if (isSigned) {
1254 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1255 } else {
1256 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1257 }
1258 encodeInstructionInto(typesGlobalValues, opcode,
1259 {typeID, resultID, words.word1, words.word2});
1260 } break;
1261 default: {
1262 std::string valueStr;
1263 llvm::raw_string_ostream rss(valueStr);
1264 value.print(rss, /*isSigned=*/false);
1265
1266 emitError(loc, "cannot serialize ")
1267 << bitwidth << "-bit integer literal: " << valueStr;
1268 return 0;
1269 }
1270 }
1271
1272 if (!isSpec) {
1273 constIDMap[intAttr] = resultID;
1274 }
1275 return resultID;
1276}
1277
1278uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
1279 IntegerAttr intAttr) {
1280 // De-duplicate graph constants.
1281 if (uint32_t id = getGraphConstantARMId(intAttr)) {
1282 return id;
1283 }
1284
1285 // Process the type for this graph constant.
1286 uint32_t typeID = 0;
1287 if (failed(processType(loc, graphConstType, typeID))) {
1288 return 0;
1289 }
1290
1291 uint32_t resultID = getNextID();
1292 APInt value = intAttr.getValue();
1293 unsigned bitwidth = value.getBitWidth();
1294 if (bitwidth > 32) {
1295 emitError(loc, "Too wide attribute for OpGraphConstantARM: ")
1296 << bitwidth << " bits";
1297 return 0;
1298 }
1299 bool isSigned = value.isSignedIntN(bitwidth);
1300
1301 uint32_t word = 0;
1302 if (isSigned) {
1303 word = static_cast<int32_t>(value.getSExtValue());
1304 } else {
1305 word = static_cast<uint32_t>(value.getZExtValue());
1306 }
1307 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM,
1308 {typeID, resultID, word});
1309 graphConstIDMap[intAttr] = resultID;
1310 return resultID;
1311}
1312
1313uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1314 bool isSpec) {
1315 if (!isSpec) {
1316 // We can de-duplicate normal constants, but not specialization constants.
1317 if (auto id = getConstantID(floatAttr)) {
1318 return id;
1319 }
1320 }
1321
1322 // Process the type for this float literal
1323 uint32_t typeID = 0;
1324 if (failed(processType(loc, floatAttr.getType(), typeID))) {
1325 return 0;
1326 }
1327
1328 auto resultID = getNextID();
1329 APFloat value = floatAttr.getValue();
1330 const llvm::fltSemantics *semantics = &value.getSemantics();
1331
1332 auto opcode =
1333 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1334
1335 if (semantics == &APFloat::IEEEsingle()) {
1336 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1337 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1338 } else if (semantics == &APFloat::IEEEdouble()) {
1339 struct DoubleWord {
1340 uint32_t word1;
1341 uint32_t word2;
1342 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1343 encodeInstructionInto(typesGlobalValues, opcode,
1344 {typeID, resultID, words.word1, words.word2});
1345 } else if (llvm::is_contained({&APFloat::IEEEhalf(), &APFloat::BFloat(),
1346 &APFloat::Float8E4M3FN(),
1347 &APFloat::Float8E5M2()},
1348 semantics)) {
1349 uint32_t word =
1350 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1351 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1352 } else {
1353 std::string valueStr;
1354 llvm::raw_string_ostream rss(valueStr);
1355 value.print(rss);
1356
1357 emitError(loc, "cannot serialize ")
1358 << floatAttr.getType() << "-typed float literal: " << valueStr;
1359 return 0;
1360 }
1361
1362 if (!isSpec) {
1363 constIDMap[floatAttr] = resultID;
1364 }
1365 return resultID;
1366}
1367
1368// Returns type of attribute. In case of a TypedAttr this will simply return
1369// the type. But for an ArrayAttr which is untyped and can be multidimensional
1370// it creates the ArrayType recursively.
1372 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1373 return typedAttr.getType();
1374 }
1375
1376 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1377 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
1378 }
1379
1380 return nullptr;
1381}
1382
1383uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
1384 Type resultType,
1385 Attribute valueAttr) {
1386 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1387 if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
1388 return id;
1389 }
1390
1391 uint32_t typeID = 0;
1392 if (failed(processType(loc, resultType, typeID))) {
1393 return 0;
1394 }
1395
1396 Type valueType = getValueType(valueAttr);
1397 if (!valueAttr)
1398 return 0;
1399
1400 auto compositeType = dyn_cast<CompositeType>(resultType);
1401 if (!compositeType)
1402 return 0;
1403 Type elementType = compositeType.getElementType(0);
1404
1405 uint32_t constandID;
1406 if (elementType == valueType) {
1407 constandID = prepareConstant(loc, elementType, valueAttr);
1408 } else {
1409 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1410 }
1411
1412 uint32_t resultID = getNextID();
1413 if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
1414 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1415 {typeID, resultID});
1416 } else {
1417 encodeInstructionInto(typesGlobalValues,
1418 spirv::Opcode::OpConstantCompositeReplicateEXT,
1419 {typeID, resultID, constandID});
1420 }
1421
1422 constCompositeReplicateIDMap[valueTypePair] = resultID;
1423 return resultID;
1424}
1425
1426//===----------------------------------------------------------------------===//
1427// Control flow
1428//===----------------------------------------------------------------------===//
1429
1430uint32_t Serializer::getOrCreateBlockID(Block *block) {
1431 if (uint32_t id = getBlockID(block))
1432 return id;
1433 return blockIDMap[block] = getNextID();
1434}
1435
1436#ifndef NDEBUG
1437void Serializer::printBlock(Block *block, raw_ostream &os) {
1438 os << "block " << block << " (id = ";
1439 if (uint32_t id = getBlockID(block))
1440 os << id;
1441 else
1442 os << "unknown";
1443 os << ")\n";
1444}
1445#endif
1446
1447LogicalResult
1448Serializer::processBlock(Block *block, bool omitLabel,
1449 function_ref<LogicalResult()> emitMerge) {
1450 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1451 LLVM_DEBUG(block->print(llvm::dbgs()));
1452 LLVM_DEBUG(llvm::dbgs() << '\n');
1453 if (!omitLabel) {
1454 uint32_t blockID = getOrCreateBlockID(block);
1455 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1456
1457 // Emit OpLabel for this block.
1458 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1459 }
1460
1461 // Emit OpPhi instructions for block arguments, if any.
1462 if (failed(emitPhiForBlockArguments(block)))
1463 return failure();
1464
1465 // If we need to emit merge instructions, it must happen in this block. Check
1466 // whether we have other structured control flow ops, which will be expanded
1467 // into multiple basic blocks. If that's the case, we need to emit the merge
1468 // right now and then create new blocks for further serialization of the ops
1469 // in this block.
1470 if (emitMerge &&
1471 llvm::any_of(block->getOperations(),
1472 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1473 if (failed(emitMerge()))
1474 return failure();
1475 emitMerge = nullptr;
1476
1477 // Start a new block for further serialization.
1478 uint32_t blockID = getNextID();
1479 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1480 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1481 }
1482
1483 // Process each op in this block except the terminator.
1484 for (Operation &op : llvm::drop_end(*block)) {
1485 if (failed(processOperation(&op)))
1486 return failure();
1487 }
1488
1489 // Process the terminator.
1490 if (emitMerge)
1491 if (failed(emitMerge()))
1492 return failure();
1493 if (failed(processOperation(&block->back())))
1494 return failure();
1495
1496 return success();
1497}
1498
1499LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1500 // Nothing to do if this block has no arguments or it's the entry block, which
1501 // always has the same arguments as the function signature.
1502 if (block->args_empty() || block->isEntryBlock())
1503 return success();
1504
1505 LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1506
1507 // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1508 // A SPIR-V OpPhi instruction is of the syntax:
1509 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1510 // So we need to collect all predecessor blocks and the arguments they send
1511 // to this block.
1512 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1513 for (Block *mlirPredecessor : block->getPredecessors()) {
1514 auto *terminator = mlirPredecessor->getTerminator();
1515 LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1516 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1517 LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1518 // The predecessor here is the immediate one according to MLIR's IR
1519 // structure. It does not directly map to the incoming parent block for the
1520 // OpPhi instructions at SPIR-V binary level. This is because structured
1521 // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1522 // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1523 // the branch op jumping to the OpPhi's block then resides in the previous
1524 // structured control flow op's merge block.
1525 Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1526 LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1527 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1528 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1529 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1530 } else if (auto branchCondOp =
1531 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1532 std::optional<OperandRange> blockOperands;
1533 if (branchCondOp.getTrueTarget() == block) {
1534 blockOperands = branchCondOp.getTrueTargetOperands();
1535 } else {
1536 assert(branchCondOp.getFalseTarget() == block);
1537 blockOperands = branchCondOp.getFalseTargetOperands();
1538 }
1539 assert(!blockOperands->empty() &&
1540 "expected non-empty block operand range");
1541 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1542 } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
1543 std::optional<OperandRange> blockOperands;
1544 if (block == switchOp.getDefaultTarget()) {
1545 blockOperands = switchOp.getDefaultOperands();
1546 } else {
1547 SuccessorRange targets = switchOp.getTargets();
1548 auto it = llvm::find(targets, block);
1549 assert(it != targets.end());
1550 size_t index = std::distance(targets.begin(), it);
1551 blockOperands = switchOp.getTargetOperands(index);
1552 }
1553 assert(!blockOperands->empty() &&
1554 "expected non-empty block operand range");
1555 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1556 } else {
1557 return terminator->emitError("unimplemented terminator for Phi creation");
1558 }
1559 LLVM_DEBUG({
1560 llvm::dbgs() << " block arguments:\n";
1561 for (Value v : predecessors.back().second)
1562 llvm::dbgs() << " " << v << "\n";
1563 });
1564 }
1565
1566 // Then create OpPhi instruction for each of the block argument.
1567 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1568 BlockArgument arg = block->getArgument(argIndex);
1569
1570 // Get the type <id> and result <id> for this OpPhi instruction.
1571 uint32_t phiTypeID = 0;
1572 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1573 return failure();
1574 uint32_t phiID = getNextID();
1575
1576 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1577 << arg << " (id = " << phiID << ")\n");
1578
1579 // Prepare the (value <id>, parent block <id>) pairs.
1580 SmallVector<uint32_t, 8> phiArgs;
1581 phiArgs.push_back(phiTypeID);
1582 phiArgs.push_back(phiID);
1583
1584 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1585 Value value = predecessors[predIndex].second[argIndex];
1586 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1587 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1588 << ") value " << value << ' ');
1589 // Each pair is a value <id> ...
1590 uint32_t valueId = getValueID(value);
1591 if (valueId == 0) {
1592 // The op generating this value hasn't been visited yet so we don't have
1593 // an <id> assigned yet. Record this to fix up later.
1594 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1595 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1596 phiArgs.size());
1597 } else {
1598 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1599 }
1600 phiArgs.push_back(valueId);
1601 // ... and a parent block <id>.
1602 phiArgs.push_back(predBlockId);
1603 }
1604
1605 encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1606 valueIDMap[arg] = phiID;
1607 }
1608
1609 return success();
1610}
1611
1612//===----------------------------------------------------------------------===//
1613// Operation
1614//===----------------------------------------------------------------------===//
1615
1616LogicalResult Serializer::encodeExtensionInstruction(
1617 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1618 ArrayRef<uint32_t> operands, SmallVectorImpl<uint32_t> &binary) {
1619 // Check if the extension has been imported.
1620 auto &setID = extendedInstSetIDMap[extensionSetName];
1621 if (!setID) {
1622 setID = getNextID();
1623 SmallVector<uint32_t, 16> importOperands;
1624 importOperands.push_back(setID);
1625 spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1626 encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1627 importOperands);
1628 }
1629
1630 // The first two operands are the result type <id> and result <id>. The set
1631 // <id> and the opcode need to be insert after this.
1632 if (operands.size() < 2) {
1633 return op->emitError("extended instructions must have a result encoding");
1634 }
1635 SmallVector<uint32_t, 8> extInstOperands;
1636 extInstOperands.reserve(operands.size() + 2);
1637 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1638 extInstOperands.push_back(setID);
1639 extInstOperands.push_back(extensionOpcode);
1640 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1641 encodeInstructionInto(binary, spirv::Opcode::OpExtInst, extInstOperands);
1642 return success();
1643}
1644
1645LogicalResult Serializer::encodeExtensionInstruction(
1646 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1647 ArrayRef<uint32_t> operands) {
1648 if (failed(encodeExtensionInstruction(op, extensionSetName, extensionOpcode,
1649 operands, functionBody)))
1650 return failure();
1651
1652 if (extensionSetName == extTosa)
1653 updateTosaOpsMap(op);
1654
1655 return success();
1656}
1657
1658LogicalResult Serializer::processOperation(Operation *opInst) {
1659 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1660
1661 // First dispatch the ops that do not directly mirror an instruction from
1662 // the SPIR-V spec.
1664 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1665 .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1666 .Case([&](spirv::BranchConditionalOp op) {
1667 return processBranchConditionalOp(op);
1668 })
1669 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1670 .Case([&](spirv::CompositeConstructOp op) {
1671 return processCompositeConstructOp(op);
1672 })
1673 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1674 return processConstantCompositeReplicateOp(op);
1675 })
1676 .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1677 .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); })
1678 .Case([&](spirv::GraphEntryPointARMOp op) {
1679 return processGraphEntryPointARMOp(op);
1680 })
1681 .Case([&](spirv::GraphOutputsARMOp op) {
1682 return processGraphOutputsARMOp(op);
1683 })
1684 .Case([&](spirv::GlobalVariableOp op) {
1685 return processGlobalVariableOp(op);
1686 })
1687 .Case([&](spirv::GraphConstantARMOp op) {
1688 return processGraphConstantARMOp(op);
1689 })
1690 .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1691 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1692 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1693 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1694 .Case([&](spirv::SpecConstantCompositeOp op) {
1695 return processSpecConstantCompositeOp(op);
1696 })
1697 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1698 return processSpecConstantCompositeReplicateOp(op);
1699 })
1700 .Case([&](spirv::SpecConstantOperationOp op) {
1701 return processSpecConstantOperationOp(op);
1702 })
1703 .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); })
1704 .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1705 .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1706
1707 // Then handle all the ops that directly mirror SPIR-V instructions with
1708 // auto-generated methods.
1709 .Default(
1710 [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1711}
1712
1713LogicalResult
1714Serializer::processCompositeConstructOp(spirv::CompositeConstructOp op) {
1715 Location loc = op.getLoc();
1716
1717 uint32_t resultTypeID = 0;
1718 if (failed(processType(loc, op.getType(), resultTypeID)))
1719 return failure();
1720
1721 uint32_t resultID = getNextID();
1722 valueIDMap[op.getResult()] = resultID;
1723
1724 SmallVector<uint32_t, 8> operands;
1725 operands.reserve(2 + op.getConstituents().size());
1726 operands.push_back(resultTypeID);
1727 operands.push_back(resultID);
1728 for (Value constituent : op.getConstituents()) {
1729 uint32_t id = getValueID(constituent);
1730 assert(id && "use before def!");
1731 operands.push_back(id);
1732 }
1733
1734 if (failed(emitDebugLine(functionBody, loc)))
1735 return failure();
1736
1737 encodeInstructionWithContinuationInto(
1738 functionBody, spirv::Opcode::OpCompositeConstruct, operands);
1739
1740 for (auto attr : op->getAttrs()) {
1741 if (failed(processDecoration(loc, resultID, attr)))
1742 return failure();
1743 }
1744
1745 return success();
1746}
1747
1748LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1749 StringRef extInstSet,
1750 uint32_t opcode) {
1751 SmallVector<uint32_t, 4> operands;
1752 Location loc = op->getLoc();
1753
1754 uint32_t resultID = 0;
1755 if (op->getNumResults() != 0) {
1756 uint32_t resultTypeID = 0;
1757 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1758 return failure();
1759 operands.push_back(resultTypeID);
1760
1761 resultID = getNextID();
1762 operands.push_back(resultID);
1763 valueIDMap[op->getResult(0)] = resultID;
1764 };
1765
1766 for (Value operand : op->getOperands())
1767 operands.push_back(getValueID(operand));
1768
1769 if (extInstSet != extTosa)
1770 // OpLine cannot be present in graphs
1771 if (failed(emitDebugLine(functionBody, loc)))
1772 return failure();
1773
1774 if (extInstSet.empty()) {
1775 encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1776 operands);
1777 } else {
1778 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1779 return failure();
1780 }
1781
1782 if (op->getNumResults() != 0) {
1783 for (auto attr : op->getAttrs()) {
1784 if (failed(processDecoration(loc, resultID, attr)))
1785 return failure();
1786 }
1787 }
1788
1789 return success();
1790}
1791
1792void Serializer::updateTosaOpsMap(Operation *op) {
1793 if (!options.emitDebugInfo)
1794 return;
1795
1796 if (auto graphOp = dyn_cast<spirv::GraphARMOp>(op->getParentOp())) {
1797 if (uint32_t graphID = getFunctionID(graphOp.getName()))
1798 tosaOpsMap[graphID][op->getLoc()].insert(op);
1799 }
1800}
1801
1802LogicalResult Serializer::emitDecoration(uint32_t target,
1803 spirv::Decoration decoration,
1804 ArrayRef<uint32_t> params) {
1805 uint32_t wordCount = 3 + params.size();
1806 llvm::append_values(
1807 decorations,
1808 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1809 static_cast<uint32_t>(decoration));
1810 llvm::append_range(decorations, params);
1811 return success();
1812}
1813
1814LogicalResult Serializer::emitDecorationId(uint32_t target,
1815 spirv::Decoration decoration,
1816 ArrayRef<uint32_t> operandIds) {
1817 uint32_t wordCount = 3 + operandIds.size();
1818 llvm::append_values(
1819 decorations,
1820 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorateId), target,
1821 static_cast<uint32_t>(decoration));
1822 llvm::append_range(decorations, operandIds);
1823 return success();
1824}
1825
1826LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1827 Location loc) {
1828 if (!options.emitDebugInfo)
1829 return success();
1830
1831 if (lastProcessedWasMergeInst) {
1832 lastProcessedWasMergeInst = false;
1833 return success();
1834 }
1835
1836 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1837 if (fileLoc)
1838 encodeInstructionInto(binary, spirv::Opcode::OpLine,
1839 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1840 return success();
1841}
1842} // namespace spirv
1843} // namespace mlir
return success()
ArrayAttr()
b getContext())
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
static bool isZeroValue(Attribute attr)
static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp)
Move all functions declaration before functions definitions.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Location getLoc() const
Return the location for this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
OpListType & getOperations()
Definition Block.h:147
Operation & back()
Definition Block.h:162
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:109
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
llvm::iplist< Operation > OpListType
This is the list of operations in the block.
Definition Block.h:146
bool getValue() const
Return the boolean value of this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
Type getElementType(unsigned) const
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::optional< spirv::Opcode > getContinuationOpcode(spirv::Opcode parent)
Returns the SPV_INTEL_long_composites continuation opcode that may follow parent, or std::nullopt if ...
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
constexpr uint32_t kMaxWordCount
Max number of words https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_universal_limits.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
constexpr llvm::StringLiteral extTosa
Extension set name for TOSA ops.
static LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147