MLIR 22.0.0git
Serializer.cpp
Go to the documentation of this file.
1//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
27#include <cstdint>
28#include <optional>
29
30#define DEBUG_TYPE "spirv-serialization"
31
32using namespace mlir;
33
34/// Returns the merge block if the given `op` is a structured control flow op.
35/// Otherwise returns nullptr.
37 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
41 return nullptr;
42}
43
44/// Given a predecessor `block` for a block with arguments, returns the block
45/// that should be used as the parent block for SPIR-V OpPhi instructions
46/// corresponding to the block arguments.
48 // If the predecessor block in question is the entry block for a
49 // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50 if (block->isEntryBlock()) {
51 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52 // Then the incoming parent block for OpPhi should be the merge block of
53 // the structured control flow op before this loop.
54 Operation *op = loopOp.getOperation();
55 while ((op = op->getPrevNode()) != nullptr)
56 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57 return incomingBlock;
58 // Or the enclosing block itself if no structured control flow ops
59 // exists before this loop.
60 return loopOp->getBlock();
61 }
62 }
63
64 // Otherwise, we jump from the given predecessor block. Try to see if there is
65 // a structured control flow op inside it.
66 for (Operation &op : llvm::reverse(block->getOperations())) {
67 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68 return incomingBlock;
69 }
70 return block;
71}
72
73static bool isZeroValue(Attribute attr) {
74 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
76 }
77 if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
79 }
80 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
82 }
83 if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
84 return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
85 }
86 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
87 return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
88 }
89 return false;
90}
91
92/// Move all functions declaration before functions definitions. In SPIR-V
93/// "declarations" are functions without a body and "definitions" functions
94/// with a body. This is stronger than necessary. It should be sufficient to
95/// ensure any declarations precede their uses and not all definitions, however
96/// this allows to avoid analysing every function in the module this way.
97static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp) {
98 Block::OpListType &ops = moduleOp.getBody()->getOperations();
99 if (ops.empty())
100 return;
101 Operation &firstOp = ops.front();
102 for (Operation &op : llvm::drop_begin(ops))
103 if (auto funcOp = dyn_cast<spirv::FuncOp>(op))
104 if (funcOp.getBody().empty())
105 funcOp->moveBefore(&firstOp);
106}
107
108namespace mlir {
109namespace spirv {
110
111/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
112/// the given `binary` vector.
114 ArrayRef<uint32_t> operands) {
115 uint32_t wordCount = 1 + operands.size();
116 binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
117 binary.append(operands.begin(), operands.end());
118}
119
120Serializer::Serializer(spirv::ModuleOp module,
121 const SerializationOptions &options)
122 : module(module), mlirBuilder(module.getContext()), options(options) {}
123
124LogicalResult Serializer::serialize() {
125 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
126
127 if (failed(module.verifyInvariants()))
128 return failure();
129
130 // TODO: handle the other sections
131 processCapability();
132 if (failed(processExtension())) {
133 return failure();
134 }
135 processMemoryModel();
136 processDebugInfo();
137
139
140 // Iterate over the module body to serialize it. Assumptions are that there is
141 // only one basic block in the moduleOp
142 for (auto &op : *module.getBody()) {
143 if (failed(processOperation(&op))) {
144 return failure();
145 }
146 }
147
148 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
149 return success();
150}
151
153 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
154 extensions.size() + extendedSets.size() +
155 memoryModel.size() + entryPoints.size() +
156 executionModes.size() + decorations.size() +
157 typesGlobalValues.size() + functions.size() + graphs.size();
158
159 binary.clear();
160 binary.reserve(moduleSize);
161
162 spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
163 nextID);
164 binary.append(capabilities.begin(), capabilities.end());
165 binary.append(extensions.begin(), extensions.end());
166 binary.append(extendedSets.begin(), extendedSets.end());
167 binary.append(memoryModel.begin(), memoryModel.end());
168 binary.append(entryPoints.begin(), entryPoints.end());
169 binary.append(executionModes.begin(), executionModes.end());
170 binary.append(debug.begin(), debug.end());
171 binary.append(names.begin(), names.end());
172 binary.append(decorations.begin(), decorations.end());
173 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
174 binary.append(functions.begin(), functions.end());
175 binary.append(graphs.begin(), graphs.end());
176}
177
178#ifndef NDEBUG
180 os << "\n= Value <id> Map =\n\n";
181 for (auto valueIDPair : valueIDMap) {
182 Value val = valueIDPair.first;
183 os << " " << val << " "
184 << "id = " << valueIDPair.second << ' ';
185 if (auto *op = val.getDefiningOp()) {
186 os << "from op '" << op->getName() << "'";
187 } else if (auto arg = dyn_cast<BlockArgument>(val)) {
188 Block *block = arg.getOwner();
189 os << "from argument of block " << block << ' ';
190 os << " in op '" << block->getParentOp()->getName() << "'";
191 }
192 os << '\n';
193 }
194}
195#endif
196
197//===----------------------------------------------------------------------===//
198// Module structure
199//===----------------------------------------------------------------------===//
200
201uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
202 auto funcID = funcIDMap.lookup(fnName);
203 if (!funcID) {
204 funcID = getNextID();
205 funcIDMap[fnName] = funcID;
206 }
207 return funcID;
208}
209
210void Serializer::processCapability() {
211 for (auto cap : module.getVceTriple()->getCapabilities())
212 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
213 {static_cast<uint32_t>(cap)});
214}
215
216void Serializer::processDebugInfo() {
217 if (!options.emitDebugInfo)
218 return;
219 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
220 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
221 fileID = getNextID();
222 SmallVector<uint32_t, 16> operands;
223 operands.push_back(fileID);
224 spirv::encodeStringLiteralInto(operands, fileName);
225 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
226 // TODO: Encode more debug instructions.
227}
228
229LogicalResult Serializer::processExtension() {
230 llvm::SmallVector<uint32_t, 16> extName;
231 llvm::SmallSet<Extension, 4> deducedExts(
232 llvm::from_range, module.getVceTriple()->getExtensions());
233 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
234 if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
235 TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module);
236 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
237 return module.emitError(
238 "SPV_KHR_non_semantic_info extension not available");
239 deducedExts.insert(nonSemanticInfoExt);
240 }
241 for (spirv::Extension ext : deducedExts) {
242 extName.clear();
243 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
244 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
245 }
246 return success();
247}
248
249void Serializer::processMemoryModel() {
250 StringAttr memoryModelName = module.getMemoryModelAttrName();
251 auto mm = static_cast<uint32_t>(
252 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
253 .getValue());
254
255 StringAttr addressingModelName = module.getAddressingModelAttrName();
256 auto am = static_cast<uint32_t>(
257 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
258 .getValue());
259
260 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
261}
262
263static std::string getDecorationName(StringRef attrName) {
264 // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
265 // expected FPFastMathMode.
266 if (attrName == "fp_fast_math_mode")
267 return "FPFastMathMode";
268 // similar here
269 if (attrName == "fp_rounding_mode")
270 return "FPRoundingMode";
271 // convertToCamelFromSnakeCase will not capitalize "INTEL".
272 if (attrName == "cache_control_load_intel")
273 return "CacheControlLoadINTEL";
274 if (attrName == "cache_control_store_intel")
275 return "CacheControlStoreINTEL";
276
277 return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
278}
279
280template <typename AttrTy, typename EmitF>
281static LogicalResult processDecorationList(Location loc, Decoration decoration,
282 Attribute attrList,
283 StringRef attrName, EmitF emitter) {
284 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
285 if (!arrayAttr) {
286 return emitError(loc, "expecting array attribute of ")
287 << attrName << " for " << stringifyDecoration(decoration);
288 }
289 if (arrayAttr.empty()) {
290 return emitError(loc, "expecting non-empty array attribute of ")
291 << attrName << " for " << stringifyDecoration(decoration);
292 }
293 for (Attribute attr : arrayAttr.getValue()) {
294 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
295 if (!cacheControlAttr) {
296 return emitError(loc, "expecting array attribute of ")
297 << attrName << " for " << stringifyDecoration(decoration);
298 }
299 // This named attribute encodes several decorations. Emit one per
300 // element in the array.
301 if (failed(emitter(cacheControlAttr)))
302 return failure();
303 }
304 return success();
305}
306
307LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
308 Decoration decoration,
309 Attribute attr) {
311 switch (decoration) {
312 case spirv::Decoration::LinkageAttributes: {
313 // Get the value of the Linkage Attributes
314 // e.g., LinkageAttributes=["linkageName", linkageType].
315 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
316 auto linkageName = linkageAttr.getLinkageName();
317 auto linkageType = linkageAttr.getLinkageType().getValue();
318 // Encode the Linkage Name (string literal to uint32_t).
319 spirv::encodeStringLiteralInto(args, linkageName);
320 // Encode LinkageType & Add the Linkagetype to the args.
321 args.push_back(static_cast<uint32_t>(linkageType));
322 break;
323 }
324 case spirv::Decoration::FPFastMathMode:
325 if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
326 args.push_back(static_cast<uint32_t>(intAttr.getValue()));
327 break;
328 }
329 return emitError(loc, "expected FPFastMathModeAttr attribute for ")
330 << stringifyDecoration(decoration);
331 case spirv::Decoration::FPRoundingMode:
332 if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
333 args.push_back(static_cast<uint32_t>(intAttr.getValue()));
334 break;
335 }
336 return emitError(loc, "expected FPRoundingModeAttr attribute for ")
337 << stringifyDecoration(decoration);
338 case spirv::Decoration::Binding:
339 case spirv::Decoration::DescriptorSet:
340 case spirv::Decoration::Location:
341 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
342 args.push_back(intAttr.getValue().getZExtValue());
343 break;
344 }
345 return emitError(loc, "expected integer attribute for ")
346 << stringifyDecoration(decoration);
347 case spirv::Decoration::BuiltIn:
348 if (auto strAttr = dyn_cast<StringAttr>(attr)) {
349 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
350 if (enumVal) {
351 args.push_back(static_cast<uint32_t>(*enumVal));
352 break;
353 }
354 return emitError(loc, "invalid ")
355 << stringifyDecoration(decoration) << " decoration attribute "
356 << strAttr.getValue();
357 }
358 return emitError(loc, "expected string attribute for ")
359 << stringifyDecoration(decoration);
360 case spirv::Decoration::Aliased:
361 case spirv::Decoration::AliasedPointer:
362 case spirv::Decoration::Flat:
363 case spirv::Decoration::NonReadable:
364 case spirv::Decoration::NonWritable:
365 case spirv::Decoration::NoPerspective:
366 case spirv::Decoration::NoSignedWrap:
367 case spirv::Decoration::NoUnsignedWrap:
368 case spirv::Decoration::RelaxedPrecision:
369 case spirv::Decoration::Restrict:
370 case spirv::Decoration::RestrictPointer:
371 case spirv::Decoration::NoContraction:
372 case spirv::Decoration::Constant:
373 case spirv::Decoration::Block:
374 case spirv::Decoration::Invariant:
375 case spirv::Decoration::Patch:
376 // For unit attributes and decoration attributes, the args list
377 // has no values so we do nothing.
378 if (isa<UnitAttr, DecorationAttr>(attr))
379 break;
380 return emitError(loc,
381 "expected unit attribute or decoration attribute for ")
382 << stringifyDecoration(decoration);
383 case spirv::Decoration::CacheControlLoadINTEL:
385 loc, decoration, attr, "CacheControlLoadINTEL",
386 [&](CacheControlLoadINTELAttr attr) {
387 unsigned cacheLevel = attr.getCacheLevel();
388 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
389 return emitDecoration(
390 resultID, decoration,
391 {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
392 });
393 case spirv::Decoration::CacheControlStoreINTEL:
395 loc, decoration, attr, "CacheControlStoreINTEL",
396 [&](CacheControlStoreINTELAttr attr) {
397 unsigned cacheLevel = attr.getCacheLevel();
398 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
399 return emitDecoration(
400 resultID, decoration,
401 {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
402 });
403 default:
404 return emitError(loc, "unhandled decoration ")
405 << stringifyDecoration(decoration);
406 }
407 return emitDecoration(resultID, decoration, args);
408}
409
410LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
411 NamedAttribute attr) {
412 StringRef attrName = attr.getName().strref();
413 std::string decorationName = getDecorationName(attrName);
414 std::optional<Decoration> decoration =
415 spirv::symbolizeDecoration(decorationName);
416 if (!decoration) {
417 return emitError(
418 loc, "non-argument attributes expected to have snake-case-ified "
419 "decoration name, unhandled attribute with name : ")
420 << attrName;
421 }
422 return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
423}
424
425LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
426 assert(!name.empty() && "unexpected empty string for OpName");
427 if (!options.emitSymbolName)
428 return success();
429
430 SmallVector<uint32_t, 4> nameOperands;
431 nameOperands.push_back(resultID);
432 spirv::encodeStringLiteralInto(nameOperands, name);
433 encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
434 return success();
435}
436
437template <>
438LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
439 Location loc, spirv::ArrayType type, uint32_t resultID) {
440 if (unsigned stride = type.getArrayStride()) {
441 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
442 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
443 }
444 return success();
445}
446
447template <>
448LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
449 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
450 if (unsigned stride = type.getArrayStride()) {
451 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
452 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
453 }
454 return success();
455}
456
457LogicalResult Serializer::processMemberDecoration(
458 uint32_t structID,
459 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
461 {structID, memberDecoration.memberIndex,
462 static_cast<uint32_t>(memberDecoration.decoration)});
463 if (memberDecoration.hasValue()) {
464 args.push_back(
465 cast<IntegerAttr>(memberDecoration.decorationValue).getInt());
466 }
467 encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// Type
473//===----------------------------------------------------------------------===//
474
475// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
476// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
477// PushConstant Storage Classes must be explicitly laid out."
478bool Serializer::isInterfaceStructPtrType(Type type) const {
479 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
480 switch (ptrType.getStorageClass()) {
481 case spirv::StorageClass::PhysicalStorageBuffer:
482 case spirv::StorageClass::PushConstant:
483 case spirv::StorageClass::StorageBuffer:
484 case spirv::StorageClass::Uniform:
485 return isa<spirv::StructType>(ptrType.getPointeeType());
486 default:
487 break;
488 }
489 }
490 return false;
491}
492
493LogicalResult Serializer::processType(Location loc, Type type,
494 uint32_t &typeID) {
495 // Maintains a set of names for nested identified struct types. This is used
496 // to properly serialize recursive references.
497 SetVector<StringRef> serializationCtx;
498 return processTypeImpl(loc, type, typeID, serializationCtx);
499}
500
501LogicalResult
502Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
503 SetVector<StringRef> &serializationCtx) {
504
505 // Map unsigned integer types to singless integer types.
506 // This is needed otherwise the generated spirv assembly will contain
507 // twice a type declaration (like OpTypeInt 32 0) which is no permitted and
508 // such module fails validation. Indeed at MLIR level the two types are
509 // different and lookup in the cache below misses.
510 // Note: This conversion needs to happen here before the type is looked up in
511 // the cache.
512 if (type.isUnsignedInteger()) {
513 type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
514 IntegerType::SignednessSemantics::Signless);
515 }
516
517 typeID = getTypeID(type);
518 if (typeID)
519 return success();
520
521 typeID = getNextID();
522 SmallVector<uint32_t, 4> operands;
523
524 operands.push_back(typeID);
525 auto typeEnum = spirv::Opcode::OpTypeVoid;
526 bool deferSerialization = false;
527
528 if ((isa<FunctionType>(type) &&
529 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
530 operands))) ||
531 (isa<GraphType>(type) &&
532 succeeded(
533 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
534 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
535 deferSerialization, serializationCtx))) {
536 if (deferSerialization)
537 return success();
538
539 typeIDMap[type] = typeID;
540
541 encodeInstructionInto(typesGlobalValues, typeEnum, operands);
542
543 if (recursiveStructInfos.count(type) != 0) {
544 // This recursive struct type is emitted already, now the OpTypePointer
545 // instructions referring to recursive references are emitted as well.
546 for (auto &ptrInfo : recursiveStructInfos[type]) {
547 // TODO: This might not work if more than 1 recursive reference is
548 // present in the struct.
549 SmallVector<uint32_t, 4> ptrOperands;
550 ptrOperands.push_back(ptrInfo.pointerTypeID);
551 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
552 ptrOperands.push_back(typeIDMap[type]);
553
554 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
555 ptrOperands);
556 }
557
558 recursiveStructInfos[type].clear();
559 }
560
561 return success();
562 }
563
564 return emitError(loc, "failed to process type: ") << type;
565}
566
567LogicalResult Serializer::prepareBasicType(
568 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
569 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
570 SetVector<StringRef> &serializationCtx) {
571 deferSerialization = false;
572
573 if (isVoidType(type)) {
574 typeEnum = spirv::Opcode::OpTypeVoid;
575 return success();
576 }
577
578 if (auto intType = dyn_cast<IntegerType>(type)) {
579 if (intType.getWidth() == 1) {
580 typeEnum = spirv::Opcode::OpTypeBool;
581 return success();
582 }
583
584 typeEnum = spirv::Opcode::OpTypeInt;
585 operands.push_back(intType.getWidth());
586 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
587 // to preserve or validate.
588 // 0 indicates unsigned, or no signedness semantics
589 // 1 indicates signed semantics."
590 operands.push_back(intType.isSigned() ? 1 : 0);
591 return success();
592 }
593
594 if (auto floatType = dyn_cast<FloatType>(type)) {
595 typeEnum = spirv::Opcode::OpTypeFloat;
596 operands.push_back(floatType.getWidth());
597 if (floatType.isBF16()) {
598 operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
599 }
600 return success();
601 }
602
603 if (auto vectorType = dyn_cast<VectorType>(type)) {
604 uint32_t elementTypeID = 0;
605 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
606 serializationCtx))) {
607 return failure();
608 }
609 typeEnum = spirv::Opcode::OpTypeVector;
610 operands.push_back(elementTypeID);
611 operands.push_back(vectorType.getNumElements());
612 return success();
613 }
614
615 if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
616 typeEnum = spirv::Opcode::OpTypeImage;
617 uint32_t sampledTypeID = 0;
618 if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
619 return failure();
620
621 llvm::append_values(operands, sampledTypeID,
622 static_cast<uint32_t>(imageType.getDim()),
623 static_cast<uint32_t>(imageType.getDepthInfo()),
624 static_cast<uint32_t>(imageType.getArrayedInfo()),
625 static_cast<uint32_t>(imageType.getSamplingInfo()),
626 static_cast<uint32_t>(imageType.getSamplerUseInfo()),
627 static_cast<uint32_t>(imageType.getImageFormat()));
628 return success();
629 }
630
631 if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
632 typeEnum = spirv::Opcode::OpTypeArray;
633 uint32_t elementTypeID = 0;
634 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
635 serializationCtx))) {
636 return failure();
637 }
638 operands.push_back(elementTypeID);
639 if (auto elementCountID = prepareConstantInt(
640 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
641 operands.push_back(elementCountID);
642 }
643 return processTypeDecoration(loc, arrayType, resultID);
644 }
645
646 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
647 uint32_t pointeeTypeID = 0;
648 spirv::StructType pointeeStruct =
649 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
650
651 if (pointeeStruct && pointeeStruct.isIdentified() &&
652 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
653 // A recursive reference to an enclosing struct is found.
654 //
655 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
656 // class as operands.
657 SmallVector<uint32_t, 2> forwardPtrOperands;
658 forwardPtrOperands.push_back(resultID);
659 forwardPtrOperands.push_back(
660 static_cast<uint32_t>(ptrType.getStorageClass()));
661
662 encodeInstructionInto(typesGlobalValues,
663 spirv::Opcode::OpTypeForwardPointer,
664 forwardPtrOperands);
665
666 // 2. Find the pointee (enclosing) struct.
667 auto structType = spirv::StructType::getIdentified(
668 module.getContext(), pointeeStruct.getIdentifier());
669
670 if (!structType)
671 return failure();
672
673 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
674 // as deferred.
675 deferSerialization = true;
676
677 // 4. Record the info needed to emit the deferred OpTypePointer
678 // instruction when the enclosing struct is completely serialized.
679 recursiveStructInfos[structType].push_back(
680 {resultID, ptrType.getStorageClass()});
681 } else {
682 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
683 serializationCtx)))
684 return failure();
685 }
686
687 typeEnum = spirv::Opcode::OpTypePointer;
688 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
689 operands.push_back(pointeeTypeID);
690
691 // TODO: Now struct decorations are supported this code may not be
692 // necessary. However, it is left to support backwards compatibility.
693 // Ideally, Block decorations should be inserted when converting to SPIR-V.
694 if (isInterfaceStructPtrType(ptrType)) {
695 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
696 if (!structType.hasDecoration(spirv::Decoration::Block))
697 if (failed(emitDecoration(getTypeID(pointeeStruct),
698 spirv::Decoration::Block)))
699 return emitError(loc, "cannot decorate ")
700 << pointeeStruct << " with Block decoration";
701 }
702
703 return success();
704 }
705
706 if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
707 uint32_t elementTypeID = 0;
708 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
709 elementTypeID, serializationCtx))) {
710 return failure();
711 }
712 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
713 operands.push_back(elementTypeID);
714 return processTypeDecoration(loc, runtimeArrayType, resultID);
715 }
716
717 if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
718 typeEnum = spirv::Opcode::OpTypeSampledImage;
719 uint32_t imageTypeID = 0;
720 if (failed(
721 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
722 return failure();
723 }
724 operands.push_back(imageTypeID);
725 return success();
726 }
727
728 if (auto structType = dyn_cast<spirv::StructType>(type)) {
729 if (structType.isIdentified()) {
730 if (failed(processName(resultID, structType.getIdentifier())))
731 return failure();
732 serializationCtx.insert(structType.getIdentifier());
733 }
734
735 bool hasOffset = structType.hasOffset();
736 for (auto elementIndex :
737 llvm::seq<uint32_t>(0, structType.getNumElements())) {
738 uint32_t elementTypeID = 0;
739 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
740 elementTypeID, serializationCtx))) {
741 return failure();
742 }
743 operands.push_back(elementTypeID);
744 if (hasOffset) {
745 auto intType = IntegerType::get(structType.getContext(), 32);
746 // Decorate each struct member with an offset
747 spirv::StructType::MemberDecorationInfo offsetDecoration{
748 elementIndex, spirv::Decoration::Offset,
749 IntegerAttr::get(intType,
750 structType.getMemberOffset(elementIndex))};
751 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
752 return emitError(loc, "cannot decorate ")
753 << elementIndex << "-th member of " << structType
754 << " with its offset";
755 }
756 }
757 }
758 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
759 structType.getMemberDecorations(memberDecorations);
760
761 for (auto &memberDecoration : memberDecorations) {
762 if (failed(processMemberDecoration(resultID, memberDecoration))) {
763 return emitError(loc, "cannot decorate ")
764 << static_cast<uint32_t>(memberDecoration.memberIndex)
765 << "-th member of " << structType << " with "
766 << stringifyDecoration(memberDecoration.decoration);
767 }
768 }
769
770 SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
771 structType.getStructDecorations(structDecorations);
772
773 for (spirv::StructType::StructDecorationInfo &structDecoration :
774 structDecorations) {
775 if (failed(processDecorationAttr(loc, resultID,
776 structDecoration.decoration,
777 structDecoration.decorationValue))) {
778 return emitError(loc, "cannot decorate struct ")
779 << structType << " with "
780 << stringifyDecoration(structDecoration.decoration);
781 }
782 }
783
784 typeEnum = spirv::Opcode::OpTypeStruct;
785
786 if (structType.isIdentified())
787 serializationCtx.remove(structType.getIdentifier());
788
789 return success();
790 }
791
792 if (auto cooperativeMatrixType =
793 dyn_cast<spirv::CooperativeMatrixType>(type)) {
794 uint32_t elementTypeID = 0;
795 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
796 elementTypeID, serializationCtx))) {
797 return failure();
798 }
799 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
800 auto getConstantOp = [&](uint32_t id) {
801 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
802 return prepareConstantInt(loc, attr);
803 };
804 llvm::append_values(
805 operands, elementTypeID,
806 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
807 getConstantOp(cooperativeMatrixType.getRows()),
808 getConstantOp(cooperativeMatrixType.getColumns()),
809 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
810 return success();
811 }
812
813 if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
814 uint32_t elementTypeID = 0;
815 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
816 serializationCtx))) {
817 return failure();
818 }
819 typeEnum = spirv::Opcode::OpTypeMatrix;
820 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
821 return success();
822 }
823
824 if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
825 uint32_t elementTypeID = 0;
826 uint32_t rank = 0;
827 uint32_t shapeID = 0;
828 uint32_t rankID = 0;
829 if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
830 elementTypeID, serializationCtx))) {
831 return failure();
832 }
833 if (tensorArmType.hasRank()) {
834 ArrayRef<int64_t> dims = tensorArmType.getShape();
835 rank = dims.size();
836 rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
837 if (rankID == 0) {
838 return failure();
839 }
840
841 bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
842 if (rank > 0 && shaped) {
843 auto I32Type = IntegerType::get(type.getContext(), 32);
844 auto shapeType = ArrayType::get(I32Type, rank);
845 if (rank == 1) {
846 SmallVector<uint64_t, 1> index(rank);
847 shapeID = prepareDenseElementsConstant(
848 loc, shapeType,
849 mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
850 index);
851 } else {
852 shapeID = prepareArrayConstant(
853 loc, shapeType,
854 mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
855 }
856 if (shapeID == 0) {
857 return failure();
858 }
859 }
860 }
861 typeEnum = spirv::Opcode::OpTypeTensorARM;
862 operands.push_back(elementTypeID);
863 if (rankID == 0)
864 return success();
865 operands.push_back(rankID);
866 if (shapeID == 0)
867 return success();
868 operands.push_back(shapeID);
869 return success();
870 }
871
872 // TODO: Handle other types.
873 return emitError(loc, "unhandled type in serialization: ") << type;
874}
875
876LogicalResult
877Serializer::prepareFunctionType(Location loc, FunctionType type,
878 spirv::Opcode &typeEnum,
879 SmallVectorImpl<uint32_t> &operands) {
880 typeEnum = spirv::Opcode::OpTypeFunction;
881 assert(type.getNumResults() <= 1 &&
882 "serialization supports only a single return value");
883 uint32_t resultID = 0;
884 if (failed(processType(
885 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
886 resultID))) {
887 return failure();
888 }
889 operands.push_back(resultID);
890 for (auto &res : type.getInputs()) {
891 uint32_t argTypeID = 0;
892 if (failed(processType(loc, res, argTypeID))) {
893 return failure();
894 }
895 operands.push_back(argTypeID);
896 }
897 return success();
898}
899
900LogicalResult
901Serializer::prepareGraphType(Location loc, GraphType type,
902 spirv::Opcode &typeEnum,
903 SmallVectorImpl<uint32_t> &operands) {
904 typeEnum = spirv::Opcode::OpTypeGraphARM;
905 assert(type.getNumResults() >= 1 &&
906 "serialization requires at least a return value");
907
908 operands.push_back(type.getNumInputs());
909
910 for (Type argType : type.getInputs()) {
911 uint32_t argTypeID = 0;
912 if (failed(processType(loc, argType, argTypeID)))
913 return failure();
914 operands.push_back(argTypeID);
915 }
916
917 for (Type resType : type.getResults()) {
918 uint32_t resTypeID = 0;
919 if (failed(processType(loc, resType, resTypeID)))
920 return failure();
921 operands.push_back(resTypeID);
922 }
923
924 return success();
925}
926
927//===----------------------------------------------------------------------===//
928// Constant
929//===----------------------------------------------------------------------===//
930
931uint32_t Serializer::prepareConstant(Location loc, Type constType,
932 Attribute valueAttr) {
933 if (auto id = prepareConstantScalar(loc, valueAttr)) {
934 return id;
935 }
936
937 // This is a composite literal. We need to handle each component separately
938 // and then emit an OpConstantComposite for the whole.
939
940 if (auto id = getConstantID(valueAttr)) {
941 return id;
942 }
943
944 uint32_t typeID = 0;
945 if (failed(processType(loc, constType, typeID))) {
946 return 0;
947 }
948
949 uint32_t resultID = 0;
950 if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
951 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
952 SmallVector<uint64_t, 4> index(rank);
953 resultID = prepareDenseElementsConstant(loc, constType, attr,
954 /*dim=*/0, index);
955 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
956 resultID = prepareArrayConstant(loc, constType, arrayAttr);
957 }
958
959 if (resultID == 0) {
960 emitError(loc, "cannot serialize attribute: ") << valueAttr;
961 return 0;
962 }
963
964 constIDMap[valueAttr] = resultID;
965 return resultID;
966}
967
968uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
969 ArrayAttr attr) {
970 uint32_t typeID = 0;
971 if (failed(processType(loc, constType, typeID))) {
972 return 0;
973 }
974
975 uint32_t resultID = getNextID();
976 SmallVector<uint32_t, 4> operands = {typeID, resultID};
977 operands.reserve(attr.size() + 2);
978 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
979 for (Attribute elementAttr : attr) {
980 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
981 operands.push_back(elementID);
982 } else {
983 return 0;
984 }
985 }
986 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
987 encodeInstructionInto(typesGlobalValues, opcode, operands);
988
989 return resultID;
990}
991
992// TODO: Turn the below function into iterative function, instead of
993// recursive function.
994uint32_t
995Serializer::prepareDenseElementsConstant(Location loc, Type constType,
996 DenseElementsAttr valueAttr, int dim,
997 MutableArrayRef<uint64_t> index) {
998 auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
999 assert(dim <= shapedType.getRank());
1000 if (shapedType.getRank() == dim) {
1001 if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
1002 return attr.getType().getElementType().isInteger(1)
1003 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
1004 : prepareConstantInt(loc,
1005 attr.getValues<IntegerAttr>()[index]);
1006 }
1007 if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
1008 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
1009 }
1010 return 0;
1011 }
1012
1013 uint32_t typeID = 0;
1014 if (failed(processType(loc, constType, typeID))) {
1015 return 0;
1016 }
1017
1018 int64_t numberOfConstituents = shapedType.getDimSize(dim);
1019 uint32_t resultID = getNextID();
1020 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1021 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
1022 if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
1023 ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
1024 if (!innerShape.empty())
1025 elementType = spirv::TensorArmType::get(innerShape, elementType);
1026 }
1027
1028 // "If the Result Type is a cooperative matrix type, then there must be only
1029 // one Constituent, with scalar type matching the cooperative matrix Component
1030 // Type, and all components of the matrix are initialized to that value."
1031 // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
1032 if (isa<spirv::CooperativeMatrixType>(constType)) {
1033 if (!valueAttr.isSplat()) {
1034 emitError(
1035 loc,
1036 "cannot serialize a non-splat value for a cooperative matrix type");
1037 return 0;
1038 }
1039 // numberOfConstituents is 1, so we only need one more elements in the
1040 // SmallVector, so the total is 3 (1 + 2).
1041 operands.reserve(3);
1042 // We set dim directly to `shapedType.getRank()` so the recursive call
1043 // directly returns the scalar type.
1044 if (auto elementID = prepareDenseElementsConstant(
1045 loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
1046 operands.push_back(elementID);
1047 } else {
1048 return 0;
1049 }
1050 } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
1051 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1052 {typeID, resultID});
1053 return resultID;
1054 } else {
1055 operands.reserve(numberOfConstituents + 2);
1056 for (int i = 0; i < numberOfConstituents; ++i) {
1057 index[dim] = i;
1058 if (auto elementID = prepareDenseElementsConstant(
1059 loc, elementType, valueAttr, dim + 1, index)) {
1060 operands.push_back(elementID);
1061 } else {
1062 return 0;
1063 }
1064 }
1065 }
1066 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1067 encodeInstructionInto(typesGlobalValues, opcode, operands);
1068
1069 return resultID;
1070}
1071
1072uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1073 bool isSpec) {
1074 if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1075 return prepareConstantFp(loc, floatAttr, isSpec);
1076 }
1077 if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1078 return prepareConstantBool(loc, boolAttr, isSpec);
1079 }
1080 if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1081 return prepareConstantInt(loc, intAttr, isSpec);
1082 }
1083
1084 return 0;
1085}
1086
1087uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1088 bool isSpec) {
1089 if (!isSpec) {
1090 // We can de-duplicate normal constants, but not specialization constants.
1091 if (auto id = getConstantID(boolAttr)) {
1092 return id;
1093 }
1094 }
1095
1096 // Process the type for this bool literal
1097 uint32_t typeID = 0;
1098 if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
1099 return 0;
1100 }
1101
1102 auto resultID = getNextID();
1103 auto opcode = boolAttr.getValue()
1104 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1105 : spirv::Opcode::OpConstantTrue)
1106 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1107 : spirv::Opcode::OpConstantFalse);
1108 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1109
1110 if (!isSpec) {
1111 constIDMap[boolAttr] = resultID;
1112 }
1113 return resultID;
1114}
1115
1116uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1117 bool isSpec) {
1118 if (!isSpec) {
1119 // We can de-duplicate normal constants, but not specialization constants.
1120 if (auto id = getConstantID(intAttr)) {
1121 return id;
1122 }
1123 }
1124
1125 // Process the type for this integer literal
1126 uint32_t typeID = 0;
1127 if (failed(processType(loc, intAttr.getType(), typeID))) {
1128 return 0;
1129 }
1130
1131 auto resultID = getNextID();
1132 APInt value = intAttr.getValue();
1133 unsigned bitwidth = value.getBitWidth();
1134 bool isSigned = intAttr.getType().isSignedInteger();
1135 auto opcode =
1136 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1137
1138 switch (bitwidth) {
1139 // According to SPIR-V spec, "When the type's bit width is less than
1140 // 32-bits, the literal's value appears in the low-order bits of the word,
1141 // and the high-order bits must be 0 for a floating-point type, or 0 for an
1142 // integer type with Signedness of 0, or sign extended when Signedness
1143 // is 1."
1144 case 32:
1145 case 16:
1146 case 8: {
1147 uint32_t word = 0;
1148 if (isSigned) {
1149 word = static_cast<int32_t>(value.getSExtValue());
1150 } else {
1151 word = static_cast<uint32_t>(value.getZExtValue());
1152 }
1153 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1154 } break;
1155 // According to SPIR-V spec: "When the type's bit width is larger than one
1156 // word, the literal’s low-order words appear first."
1157 case 64: {
1158 struct DoubleWord {
1159 uint32_t word1;
1160 uint32_t word2;
1161 } words;
1162 if (isSigned) {
1163 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1164 } else {
1165 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1166 }
1167 encodeInstructionInto(typesGlobalValues, opcode,
1168 {typeID, resultID, words.word1, words.word2});
1169 } break;
1170 default: {
1171 std::string valueStr;
1172 llvm::raw_string_ostream rss(valueStr);
1173 value.print(rss, /*isSigned=*/false);
1174
1175 emitError(loc, "cannot serialize ")
1176 << bitwidth << "-bit integer literal: " << valueStr;
1177 return 0;
1178 }
1179 }
1180
1181 if (!isSpec) {
1182 constIDMap[intAttr] = resultID;
1183 }
1184 return resultID;
1185}
1186
1187uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
1188 IntegerAttr intAttr) {
1189 // De-duplicate graph constants.
1190 if (uint32_t id = getGraphConstantARMId(intAttr)) {
1191 return id;
1192 }
1193
1194 // Process the type for this graph constant.
1195 uint32_t typeID = 0;
1196 if (failed(processType(loc, graphConstType, typeID))) {
1197 return 0;
1198 }
1199
1200 uint32_t resultID = getNextID();
1201 APInt value = intAttr.getValue();
1202 unsigned bitwidth = value.getBitWidth();
1203 if (bitwidth > 32) {
1204 emitError(loc, "Too wide attribute for OpGraphConstantARM: ")
1205 << bitwidth << " bits";
1206 return 0;
1207 }
1208 bool isSigned = value.isSignedIntN(bitwidth);
1209
1210 uint32_t word = 0;
1211 if (isSigned) {
1212 word = static_cast<int32_t>(value.getSExtValue());
1213 } else {
1214 word = static_cast<uint32_t>(value.getZExtValue());
1215 }
1216 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM,
1217 {typeID, resultID, word});
1218 graphConstIDMap[intAttr] = resultID;
1219 return resultID;
1220}
1221
1222uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1223 bool isSpec) {
1224 if (!isSpec) {
1225 // We can de-duplicate normal constants, but not specialization constants.
1226 if (auto id = getConstantID(floatAttr)) {
1227 return id;
1228 }
1229 }
1230
1231 // Process the type for this float literal
1232 uint32_t typeID = 0;
1233 if (failed(processType(loc, floatAttr.getType(), typeID))) {
1234 return 0;
1235 }
1236
1237 auto resultID = getNextID();
1238 APFloat value = floatAttr.getValue();
1239 const llvm::fltSemantics *semantics = &value.getSemantics();
1240
1241 auto opcode =
1242 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1243
1244 if (semantics == &APFloat::IEEEsingle()) {
1245 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1246 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1247 } else if (semantics == &APFloat::IEEEdouble()) {
1248 struct DoubleWord {
1249 uint32_t word1;
1250 uint32_t word2;
1251 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1252 encodeInstructionInto(typesGlobalValues, opcode,
1253 {typeID, resultID, words.word1, words.word2});
1254 } else if (semantics == &APFloat::IEEEhalf() ||
1255 semantics == &APFloat::BFloat()) {
1256 uint32_t word =
1257 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1258 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1259 } else {
1260 std::string valueStr;
1261 llvm::raw_string_ostream rss(valueStr);
1262 value.print(rss);
1263
1264 emitError(loc, "cannot serialize ")
1265 << floatAttr.getType() << "-typed float literal: " << valueStr;
1266 return 0;
1267 }
1268
1269 if (!isSpec) {
1270 constIDMap[floatAttr] = resultID;
1271 }
1272 return resultID;
1273}
1274
1275// Returns type of attribute. In case of a TypedAttr this will simply return
1276// the type. But for an ArrayAttr which is untyped and can be multidimensional
1277// it creates the ArrayType recursively.
1279 if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1280 return typedAttr.getType();
1281 }
1282
1283 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1284 return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
1285 }
1286
1287 return nullptr;
1288}
1289
1290uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
1291 Type resultType,
1292 Attribute valueAttr) {
1293 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1294 if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
1295 return id;
1296 }
1297
1298 uint32_t typeID = 0;
1299 if (failed(processType(loc, resultType, typeID))) {
1300 return 0;
1301 }
1302
1303 Type valueType = getValueType(valueAttr);
1304 if (!valueAttr)
1305 return 0;
1306
1307 auto compositeType = dyn_cast<CompositeType>(resultType);
1308 if (!compositeType)
1309 return 0;
1310 Type elementType = compositeType.getElementType(0);
1311
1312 uint32_t constandID;
1313 if (elementType == valueType) {
1314 constandID = prepareConstant(loc, elementType, valueAttr);
1315 } else {
1316 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1317 }
1318
1319 uint32_t resultID = getNextID();
1320 if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
1321 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1322 {typeID, resultID});
1323 } else {
1324 encodeInstructionInto(typesGlobalValues,
1325 spirv::Opcode::OpConstantCompositeReplicateEXT,
1326 {typeID, resultID, constandID});
1327 }
1328
1329 constCompositeReplicateIDMap[valueTypePair] = resultID;
1330 return resultID;
1331}
1332
1333//===----------------------------------------------------------------------===//
1334// Control flow
1335//===----------------------------------------------------------------------===//
1336
1337uint32_t Serializer::getOrCreateBlockID(Block *block) {
1338 if (uint32_t id = getBlockID(block))
1339 return id;
1340 return blockIDMap[block] = getNextID();
1341}
1342
1343#ifndef NDEBUG
1344void Serializer::printBlock(Block *block, raw_ostream &os) {
1345 os << "block " << block << " (id = ";
1346 if (uint32_t id = getBlockID(block))
1347 os << id;
1348 else
1349 os << "unknown";
1350 os << ")\n";
1351}
1352#endif
1353
1354LogicalResult
1355Serializer::processBlock(Block *block, bool omitLabel,
1356 function_ref<LogicalResult()> emitMerge) {
1357 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
1358 LLVM_DEBUG(block->print(llvm::dbgs()));
1359 LLVM_DEBUG(llvm::dbgs() << '\n');
1360 if (!omitLabel) {
1361 uint32_t blockID = getOrCreateBlockID(block);
1362 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1363
1364 // Emit OpLabel for this block.
1365 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1366 }
1367
1368 // Emit OpPhi instructions for block arguments, if any.
1369 if (failed(emitPhiForBlockArguments(block)))
1370 return failure();
1371
1372 // If we need to emit merge instructions, it must happen in this block. Check
1373 // whether we have other structured control flow ops, which will be expanded
1374 // into multiple basic blocks. If that's the case, we need to emit the merge
1375 // right now and then create new blocks for further serialization of the ops
1376 // in this block.
1377 if (emitMerge &&
1378 llvm::any_of(block->getOperations(),
1379 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1380 if (failed(emitMerge()))
1381 return failure();
1382 emitMerge = nullptr;
1383
1384 // Start a new block for further serialization.
1385 uint32_t blockID = getNextID();
1386 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
1387 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
1388 }
1389
1390 // Process each op in this block except the terminator.
1391 for (Operation &op : llvm::drop_end(*block)) {
1392 if (failed(processOperation(&op)))
1393 return failure();
1394 }
1395
1396 // Process the terminator.
1397 if (emitMerge)
1398 if (failed(emitMerge()))
1399 return failure();
1400 if (failed(processOperation(&block->back())))
1401 return failure();
1402
1403 return success();
1404}
1405
1406LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
1407 // Nothing to do if this block has no arguments or it's the entry block, which
1408 // always has the same arguments as the function signature.
1409 if (block->args_empty() || block->isEntryBlock())
1410 return success();
1411
1412 LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");
1413
1414 // If the block has arguments, we need to create SPIR-V OpPhi instructions.
1415 // A SPIR-V OpPhi instruction is of the syntax:
1416 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
1417 // So we need to collect all predecessor blocks and the arguments they send
1418 // to this block.
1419 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1420 for (Block *mlirPredecessor : block->getPredecessors()) {
1421 auto *terminator = mlirPredecessor->getTerminator();
1422 LLVM_DEBUG(llvm::dbgs() << " mlir predecessor ");
1423 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1424 LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n");
1425 // The predecessor here is the immediate one according to MLIR's IR
1426 // structure. It does not directly map to the incoming parent block for the
1427 // OpPhi instructions at SPIR-V binary level. This is because structured
1428 // control flow ops are serialized to multiple SPIR-V blocks. If there is a
1429 // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
1430 // the branch op jumping to the OpPhi's block then resides in the previous
1431 // structured control flow op's merge block.
1432 Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
1433 LLVM_DEBUG(llvm::dbgs() << " spirv predecessor ");
1434 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1435 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1436 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1437 } else if (auto branchCondOp =
1438 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1439 std::optional<OperandRange> blockOperands;
1440 if (branchCondOp.getTrueTarget() == block) {
1441 blockOperands = branchCondOp.getTrueTargetOperands();
1442 } else {
1443 assert(branchCondOp.getFalseTarget() == block);
1444 blockOperands = branchCondOp.getFalseTargetOperands();
1445 }
1446
1447 assert(!blockOperands->empty() &&
1448 "expected non-empty block operand range");
1449 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1450 } else {
1451 return terminator->emitError("unimplemented terminator for Phi creation");
1452 }
1453 LLVM_DEBUG({
1454 llvm::dbgs() << " block arguments:\n";
1455 for (Value v : predecessors.back().second)
1456 llvm::dbgs() << " " << v << "\n";
1457 });
1458 }
1459
1460 // Then create OpPhi instruction for each of the block argument.
1461 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
1462 BlockArgument arg = block->getArgument(argIndex);
1463
1464 // Get the type <id> and result <id> for this OpPhi instruction.
1465 uint32_t phiTypeID = 0;
1466 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
1467 return failure();
1468 uint32_t phiID = getNextID();
1469
1470 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
1471 << arg << " (id = " << phiID << ")\n");
1472
1473 // Prepare the (value <id>, parent block <id>) pairs.
1474 SmallVector<uint32_t, 8> phiArgs;
1475 phiArgs.push_back(phiTypeID);
1476 phiArgs.push_back(phiID);
1477
1478 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1479 Value value = predecessors[predIndex].second[argIndex];
1480 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1481 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
1482 << ") value " << value << ' ');
1483 // Each pair is a value <id> ...
1484 uint32_t valueId = getValueID(value);
1485 if (valueId == 0) {
1486 // The op generating this value hasn't been visited yet so we don't have
1487 // an <id> assigned yet. Record this to fix up later.
1488 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
1489 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1490 phiArgs.size());
1491 } else {
1492 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
1493 }
1494 phiArgs.push_back(valueId);
1495 // ... and a parent block <id>.
1496 phiArgs.push_back(predBlockId);
1497 }
1498
1499 encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
1500 valueIDMap[arg] = phiID;
1501 }
1502
1503 return success();
1504}
1505
1506//===----------------------------------------------------------------------===//
1507// Operation
1508//===----------------------------------------------------------------------===//
1509
1510LogicalResult Serializer::encodeExtensionInstruction(
1511 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1512 ArrayRef<uint32_t> operands) {
1513 // Check if the extension has been imported.
1514 auto &setID = extendedInstSetIDMap[extensionSetName];
1515 if (!setID) {
1516 setID = getNextID();
1517 SmallVector<uint32_t, 16> importOperands;
1518 importOperands.push_back(setID);
1519 spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1520 encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1521 importOperands);
1522 }
1523
1524 // The first two operands are the result type <id> and result <id>. The set
1525 // <id> and the opcode need to be insert after this.
1526 if (operands.size() < 2) {
1527 return op->emitError("extended instructions must have a result encoding");
1528 }
1529 SmallVector<uint32_t, 8> extInstOperands;
1530 extInstOperands.reserve(operands.size() + 2);
1531 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1532 extInstOperands.push_back(setID);
1533 extInstOperands.push_back(extensionOpcode);
1534 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1535 encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1536 extInstOperands);
1537 return success();
1538}
1539
1540LogicalResult Serializer::processOperation(Operation *opInst) {
1541 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");
1542
1543 // First dispatch the ops that do not directly mirror an instruction from
1544 // the SPIR-V spec.
1546 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
1547 .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
1548 .Case([&](spirv::BranchConditionalOp op) {
1549 return processBranchConditionalOp(op);
1550 })
1551 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
1552 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1553 return processConstantCompositeReplicateOp(op);
1554 })
1555 .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
1556 .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); })
1557 .Case([&](spirv::GraphEntryPointARMOp op) {
1558 return processGraphEntryPointARMOp(op);
1559 })
1560 .Case([&](spirv::GraphOutputsARMOp op) {
1561 return processGraphOutputsARMOp(op);
1562 })
1563 .Case([&](spirv::GlobalVariableOp op) {
1564 return processGlobalVariableOp(op);
1565 })
1566 .Case([&](spirv::GraphConstantARMOp op) {
1567 return processGraphConstantARMOp(op);
1568 })
1569 .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
1570 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
1571 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
1572 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
1573 .Case([&](spirv::SpecConstantCompositeOp op) {
1574 return processSpecConstantCompositeOp(op);
1575 })
1576 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1577 return processSpecConstantCompositeReplicateOp(op);
1578 })
1579 .Case([&](spirv::SpecConstantOperationOp op) {
1580 return processSpecConstantOperationOp(op);
1581 })
1582 .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
1583 .Case([&](spirv::VariableOp op) { return processVariableOp(op); })
1584
1585 // Then handle all the ops that directly mirror SPIR-V instructions with
1586 // auto-generated methods.
1587 .Default(
1588 [&](Operation *op) { return dispatchToAutogenSerialization(op); });
1589}
1590
1591LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1592 StringRef extInstSet,
1593 uint32_t opcode) {
1594 SmallVector<uint32_t, 4> operands;
1595 Location loc = op->getLoc();
1596
1597 uint32_t resultID = 0;
1598 if (op->getNumResults() != 0) {
1599 uint32_t resultTypeID = 0;
1600 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1601 return failure();
1602 operands.push_back(resultTypeID);
1603
1604 resultID = getNextID();
1605 operands.push_back(resultID);
1606 valueIDMap[op->getResult(0)] = resultID;
1607 };
1608
1609 for (Value operand : op->getOperands())
1610 operands.push_back(getValueID(operand));
1611
1612 if (failed(emitDebugLine(functionBody, loc)))
1613 return failure();
1614
1615 if (extInstSet.empty()) {
1616 encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1617 operands);
1618 } else {
1619 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1620 return failure();
1621 }
1622
1623 if (op->getNumResults() != 0) {
1624 for (auto attr : op->getAttrs()) {
1625 if (failed(processDecoration(loc, resultID, attr)))
1626 return failure();
1627 }
1628 }
1629
1630 return success();
1631}
1632
1633LogicalResult Serializer::emitDecoration(uint32_t target,
1634 spirv::Decoration decoration,
1635 ArrayRef<uint32_t> params) {
1636 uint32_t wordCount = 3 + params.size();
1637 llvm::append_values(
1638 decorations,
1639 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1640 static_cast<uint32_t>(decoration));
1641 llvm::append_range(decorations, params);
1642 return success();
1643}
1644
1645LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1646 Location loc) {
1647 if (!options.emitDebugInfo)
1648 return success();
1649
1650 if (lastProcessedWasMergeInst) {
1651 lastProcessedWasMergeInst = false;
1652 return success();
1653 }
1654
1655 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1656 if (fileLoc)
1657 encodeInstructionInto(binary, spirv::Opcode::OpLine,
1658 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1659 return success();
1660}
1661} // namespace spirv
1662} // namespace mlir
return success()
ArrayAttr()
b getContext())
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
static bool isZeroValue(Attribute attr)
static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp)
Move all functions declaration before functions definitions.
static bool isZeroValue(Value val)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Location getLoc() const
Return the location for this argument.
Definition Value.h:324
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:240
OpListType & getOperations()
Definition Block.h:137
Operation & back()
Definition Block.h:152
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:99
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
llvm::iplist< Operation > OpListType
This is the list of operations in the block.
Definition Block.h:136
bool getValue() const
Return the boolean value of this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:512
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPRI-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152