MLIR 23.0.0git
Serializer.cpp
Go to the documentation of this file.
1//===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the MLIR SPIR-V module to SPIR-V binary serializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
27#include <cstdint>
28#include <optional>
29
30#define DEBUG_TYPE "spirv-serialization"
31
32using namespace mlir;
33
34/// Returns the merge block if the given `op` is a structured control flow op.
35/// Otherwise returns nullptr.
37 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
41 return nullptr;
42}
43
44/// Given a predecessor `block` for a block with arguments, returns the block
45/// that should be used as the parent block for SPIR-V OpPhi instructions
46/// corresponding to the block arguments.
48 // If the predecessor block in question is the entry block for a
49 // spirv.mlir.loop, we jump to this spirv.mlir.loop from its enclosing block.
50 if (block->isEntryBlock()) {
51 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
52 // Then the incoming parent block for OpPhi should be the merge block of
53 // the structured control flow op before this loop.
54 Operation *op = loopOp.getOperation();
55 while ((op = op->getPrevNode()) != nullptr)
56 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
57 return incomingBlock;
58 // Or the enclosing block itself if no structured control flow ops
59 // exists before this loop.
60 return loopOp->getBlock();
61 }
62 }
63
64 // Otherwise, we jump from the given predecessor block. Try to see if there is
65 // a structured control flow op inside it.
66 for (Operation &op : llvm::reverse(block->getOperations())) {
67 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
68 return incomingBlock;
69 }
70 return block;
71}
72
73static bool isZeroValue(Attribute attr) {
74 if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
76 }
77 if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
79 }
80 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
82 }
83 if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
84 return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
85 }
86 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
87 return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
88 }
89 return false;
90}
91
92/// Move all functions declaration before functions definitions. In SPIR-V
93/// "declarations" are functions without a body and "definitions" functions
94/// with a body. This is stronger than necessary. It should be sufficient to
95/// ensure any declarations precede their uses and not all definitions, however
96/// this allows to avoid analysing every function in the module this way.
97static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp) {
98 Block::OpListType &ops = moduleOp.getBody()->getOperations();
99 if (ops.empty())
100 return;
101 Operation &firstOp = ops.front();
102 for (Operation &op : llvm::drop_begin(ops))
103 if (auto funcOp = dyn_cast<spirv::FuncOp>(op))
104 if (funcOp.getBody().empty())
105 funcOp->moveBefore(&firstOp);
106}
107
108namespace mlir {
109namespace spirv {
110
111/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
112/// the given `binary` vector.
114 ArrayRef<uint32_t> operands) {
115 uint32_t wordCount = 1 + operands.size();
116 binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
117 binary.append(operands.begin(), operands.end());
118}
119
120Serializer::Serializer(spirv::ModuleOp module,
121 const SerializationOptions &options)
122 : module(module), mlirBuilder(module.getContext()), options(options) {}
123
124LogicalResult Serializer::serialize() {
125 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
126
127 if (failed(module.verifyInvariants()))
128 return failure();
129
130 // TODO: handle the other sections
131 processCapability();
132 if (failed(processExtension())) {
133 return failure();
134 }
135 processMemoryModel();
136 processDebugInfo();
137
139
140 // Iterate over the module body to serialize it. Assumptions are that there is
141 // only one basic block in the moduleOp
142 for (auto &op : *module.getBody()) {
143 if (failed(processOperation(&op))) {
144 return failure();
145 }
146 }
147
148 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
149 return success();
150}
151
153 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
154 extensions.size() + extendedSets.size() +
155 memoryModel.size() + entryPoints.size() +
156 executionModes.size() + decorations.size() +
157 typesGlobalValues.size() + functions.size() + graphs.size();
158
159 binary.clear();
160 binary.reserve(moduleSize);
161
162 spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
163 nextID);
164 binary.append(capabilities.begin(), capabilities.end());
165 binary.append(extensions.begin(), extensions.end());
166 binary.append(extendedSets.begin(), extendedSets.end());
167 binary.append(memoryModel.begin(), memoryModel.end());
168 binary.append(entryPoints.begin(), entryPoints.end());
169 binary.append(executionModes.begin(), executionModes.end());
170 binary.append(debug.begin(), debug.end());
171 binary.append(names.begin(), names.end());
172 binary.append(decorations.begin(), decorations.end());
173 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
174 binary.append(functions.begin(), functions.end());
175 binary.append(graphs.begin(), graphs.end());
176}
177
178#ifndef NDEBUG
180 os << "\n= Value <id> Map =\n\n";
181 for (auto valueIDPair : valueIDMap) {
182 Value val = valueIDPair.first;
183 os << " " << val << " "
184 << "id = " << valueIDPair.second << ' ';
185 if (auto *op = val.getDefiningOp()) {
186 os << "from op '" << op->getName() << "'";
187 } else if (auto arg = dyn_cast<BlockArgument>(val)) {
188 Block *block = arg.getOwner();
189 os << "from argument of block " << block << ' ';
190 os << " in op '" << block->getParentOp()->getName() << "'";
191 }
192 os << '\n';
193 }
194}
195#endif
196
197//===----------------------------------------------------------------------===//
198// Module structure
199//===----------------------------------------------------------------------===//
200
201uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
202 auto funcID = funcIDMap.lookup(fnName);
203 if (!funcID) {
204 funcID = getNextID();
205 funcIDMap[fnName] = funcID;
206 }
207 return funcID;
208}
209
210void Serializer::processCapability() {
211 for (auto cap : module.getVceTriple()->getCapabilities())
212 encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
213 {static_cast<uint32_t>(cap)});
214}
215
216void Serializer::addLongCompositesCapability() {
217 if (longCompositesEmitted)
218 return;
219 longCompositesEmitted = true;
220 auto vceTriple = module.getVceTriple();
221 if (!llvm::is_contained(vceTriple->getCapabilities(),
222 spirv::Capability::LongCompositesINTEL))
224 capabilities, spirv::Opcode::OpCapability,
225 {static_cast<uint32_t>(spirv::Capability::LongCompositesINTEL)});
226 if (!llvm::is_contained(vceTriple->getExtensions(),
227 spirv::Extension::SPV_INTEL_long_composites)) {
228 SmallVector<uint32_t, 8> extName;
230 extName,
231 spirv::stringifyExtension(spirv::Extension::SPV_INTEL_long_composites));
232 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
233 }
234}
235
236void Serializer::encodeInstructionWithContinuationInto(
237 SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
238 ArrayRef<uint32_t> operands) {
239 if (1 + operands.size() <= spirv::kMaxWordCount) {
240 encodeInstructionInto(binary, op, operands);
241 return;
242 }
243
244 std::optional<spirv::Opcode> continuationOp =
246 assert(continuationOp && "op is not a splittable composite/struct opcode");
247
248 const unsigned chunk = spirv::kMaxWordCount - 1;
249 encodeInstructionInto(binary, op, operands.take_front(chunk));
250 for (ArrayRef<uint32_t> rest = operands.drop_front(chunk); !rest.empty();
251 rest = rest.drop_front(std::min<size_t>(rest.size(), chunk))) {
252 encodeInstructionInto(binary, *continuationOp, rest.take_front(chunk));
253 }
254
255 addLongCompositesCapability();
256}
257
258void Serializer::processDebugInfo() {
259 if (!options.emitDebugInfo)
260 return;
261 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
262 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
263 fileID = getNextID();
264 SmallVector<uint32_t, 16> operands;
265 operands.push_back(fileID);
266 spirv::encodeStringLiteralInto(operands, fileName);
267 encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
268 // TODO: Encode more debug instructions.
269}
270
271LogicalResult Serializer::processExtension() {
272 llvm::SmallVector<uint32_t, 16> extName;
273 llvm::SmallSet<Extension, 4> deducedExts(
274 llvm::from_range, module.getVceTriple()->getExtensions());
275 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
276 if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
277 TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module);
278 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
279 return module.emitError(
280 "SPV_KHR_non_semantic_info extension not available");
281 deducedExts.insert(nonSemanticInfoExt);
282 }
283 for (spirv::Extension ext : deducedExts) {
284 extName.clear();
285 spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
286 encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
287 }
288 return success();
289}
290
291void Serializer::processMemoryModel() {
292 StringAttr memoryModelName = module.getMemoryModelAttrName();
293 auto mm = static_cast<uint32_t>(
294 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
295 .getValue());
296
297 StringAttr addressingModelName = module.getAddressingModelAttrName();
298 auto am = static_cast<uint32_t>(
299 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
300 .getValue());
301
302 encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
303}
304
305static std::string getDecorationName(StringRef attrName) {
306 // convertToCamelFromSnakeCase will convert this to FpFastMathMode instead of
307 // expected FPFastMathMode.
308 if (attrName == "fp_fast_math_mode")
309 return "FPFastMathMode";
310 // similar here
311 if (attrName == "fp_rounding_mode")
312 return "FPRoundingMode";
313 // convertToCamelFromSnakeCase will not capitalize "INTEL".
314 if (attrName == "cache_control_load_intel")
315 return "CacheControlLoadINTEL";
316 if (attrName == "cache_control_store_intel")
317 return "CacheControlStoreINTEL";
318
319 return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
320}
321
322template <typename AttrTy, typename EmitF>
323static LogicalResult processDecorationList(Location loc, Decoration decoration,
324 Attribute attrList,
325 StringRef attrName, EmitF emitter) {
326 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
327 if (!arrayAttr) {
328 return emitError(loc, "expecting array attribute of ")
329 << attrName << " for " << stringifyDecoration(decoration);
330 }
331 if (arrayAttr.empty()) {
332 return emitError(loc, "expecting non-empty array attribute of ")
333 << attrName << " for " << stringifyDecoration(decoration);
334 }
335 for (Attribute attr : arrayAttr.getValue()) {
336 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
337 if (!cacheControlAttr) {
338 return emitError(loc, "expecting array attribute of ")
339 << attrName << " for " << stringifyDecoration(decoration);
340 }
341 // This named attribute encodes several decorations. Emit one per
342 // element in the array.
343 if (failed(emitter(cacheControlAttr)))
344 return failure();
345 }
346 return success();
347}
348
349LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
350 Decoration decoration,
351 Attribute attr) {
353 switch (decoration) {
354 case spirv::Decoration::LinkageAttributes: {
355 // Get the value of the Linkage Attributes
356 // e.g., LinkageAttributes=["linkageName", linkageType].
357 auto linkageAttr = dyn_cast<spirv::LinkageAttributesAttr>(attr);
358 auto linkageName = linkageAttr.getLinkageName();
359 auto linkageType = linkageAttr.getLinkageType().getValue();
360 // Encode the Linkage Name (string literal to uint32_t).
361 spirv::encodeStringLiteralInto(args, linkageName);
362 // Encode LinkageType & Add the Linkagetype to the args.
363 args.push_back(static_cast<uint32_t>(linkageType));
364 break;
365 }
366 case spirv::Decoration::FPFastMathMode:
367 if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
368 args.push_back(static_cast<uint32_t>(intAttr.getValue()));
369 break;
370 }
371 return emitError(loc, "expected FPFastMathModeAttr attribute for ")
372 << stringifyDecoration(decoration);
373 case spirv::Decoration::FPRoundingMode:
374 if (auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
375 args.push_back(static_cast<uint32_t>(intAttr.getValue()));
376 break;
377 }
378 return emitError(loc, "expected FPRoundingModeAttr attribute for ")
379 << stringifyDecoration(decoration);
380 case spirv::Decoration::Binding:
381 case spirv::Decoration::DescriptorSet:
382 case spirv::Decoration::Location:
383 case spirv::Decoration::Index:
384 case spirv::Decoration::Offset:
385 case spirv::Decoration::XfbBuffer:
386 case spirv::Decoration::XfbStride:
387 if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
388 args.push_back(intAttr.getValue().getZExtValue());
389 break;
390 }
391 return emitError(loc, "expected integer attribute for ")
392 << stringifyDecoration(decoration);
393 case spirv::Decoration::BuiltIn:
394 if (auto strAttr = dyn_cast<StringAttr>(attr)) {
395 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
396 if (enumVal) {
397 args.push_back(static_cast<uint32_t>(*enumVal));
398 break;
399 }
400 return emitError(loc, "invalid ")
401 << stringifyDecoration(decoration) << " decoration attribute "
402 << strAttr.getValue();
403 }
404 return emitError(loc, "expected string attribute for ")
405 << stringifyDecoration(decoration);
406 case spirv::Decoration::Aliased:
407 case spirv::Decoration::AliasedPointer:
408 case spirv::Decoration::Flat:
409 case spirv::Decoration::NonReadable:
410 case spirv::Decoration::NonWritable:
411 case spirv::Decoration::NoPerspective:
412 case spirv::Decoration::NoSignedWrap:
413 case spirv::Decoration::NoUnsignedWrap:
414 case spirv::Decoration::RelaxedPrecision:
415 case spirv::Decoration::Restrict:
416 case spirv::Decoration::RestrictPointer:
417 case spirv::Decoration::NoContraction:
418 case spirv::Decoration::Constant:
419 case spirv::Decoration::Block:
420 case spirv::Decoration::Invariant:
421 case spirv::Decoration::Patch:
422 case spirv::Decoration::Coherent:
423 // For unit attributes and decoration attributes, the args list
424 // has no values so we do nothing.
425 if (isa<UnitAttr, DecorationAttr>(attr))
426 break;
427 return emitError(loc,
428 "expected unit attribute or decoration attribute for ")
429 << stringifyDecoration(decoration);
430 case spirv::Decoration::CacheControlLoadINTEL:
432 loc, decoration, attr, "CacheControlLoadINTEL",
433 [&](CacheControlLoadINTELAttr attr) {
434 unsigned cacheLevel = attr.getCacheLevel();
435 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
436 return emitDecoration(
437 resultID, decoration,
438 {cacheLevel, static_cast<uint32_t>(loadCacheControl)});
439 });
440 case spirv::Decoration::CacheControlStoreINTEL:
442 loc, decoration, attr, "CacheControlStoreINTEL",
443 [&](CacheControlStoreINTELAttr attr) {
444 unsigned cacheLevel = attr.getCacheLevel();
445 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
446 return emitDecoration(
447 resultID, decoration,
448 {cacheLevel, static_cast<uint32_t>(storeCacheControl)});
449 });
450 case spirv::Decoration::AlignmentId:
451 case spirv::Decoration::MaxByteOffsetId:
452 case spirv::Decoration::CounterBuffer: {
453 auto symRef = dyn_cast<FlatSymbolRefAttr>(attr);
454 if (!symRef)
455 return emitError(loc, "expected symbol reference for ")
456 << stringifyDecoration(decoration);
457 StringRef symName = symRef.getValue();
458 uint32_t operandID = getVariableID(symName);
459 if (!operandID)
460 operandID = getSpecConstID(symName);
461 if (!operandID)
462 return emitError(loc, "could not find <id> for symbol '")
463 << symName << "' referenced by "
464 << stringifyDecoration(decoration);
465 return emitDecorationId(resultID, decoration, {operandID});
466 }
467 default:
468 return emitError(loc, "unhandled decoration ")
469 << stringifyDecoration(decoration);
470 }
471 return emitDecoration(resultID, decoration, args);
472}
473
474LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
475 NamedAttribute attr) {
476 StringRef attrName = attr.getName().strref();
477 std::string decorationName = getDecorationName(attrName);
478 std::optional<Decoration> decoration =
479 spirv::symbolizeDecoration(decorationName);
480 if (!decoration) {
481 return emitError(
482 loc, "non-argument attributes expected to have snake-case-ified "
483 "decoration name, unhandled attribute with name : ")
484 << attrName;
485 }
486 return processDecorationAttr(loc, resultID, *decoration, attr.getValue());
487}
488
489LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
490 assert(!name.empty() && "unexpected empty string for OpName");
491 if (!options.emitSymbolName)
492 return success();
493
494 SmallVector<uint32_t, 4> nameOperands;
495 nameOperands.push_back(resultID);
496 spirv::encodeStringLiteralInto(nameOperands, name);
497 encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
498 return success();
499}
500
501template <>
502LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
503 Location loc, spirv::ArrayType type, uint32_t resultID) {
504 if (unsigned stride = type.getArrayStride()) {
505 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
506 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
507 }
508 return success();
509}
510
511template <>
512LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
513 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
514 if (unsigned stride = type.getArrayStride()) {
515 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral
516 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
517 }
518 return success();
519}
520
521LogicalResult Serializer::processMemberDecoration(
522 uint32_t structID,
523 const spirv::StructType::MemberDecorationInfo &memberDecoration) {
525 {structID, memberDecoration.memberIndex,
526 static_cast<uint32_t>(memberDecoration.decoration)});
527 if (memberDecoration.hasValue()) {
528 args.push_back(
529 cast<IntegerAttr>(memberDecoration.decorationValue).getInt());
530 }
531 encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args);
532 return success();
533}
534
535//===----------------------------------------------------------------------===//
536// Type
537//===----------------------------------------------------------------------===//
538
539// According to the SPIR-V spec "Validation Rules for Shader Capabilities":
540// "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and
541// PushConstant Storage Classes must be explicitly laid out."
542bool Serializer::isInterfaceStructPtrType(Type type) const {
543 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
544 switch (ptrType.getStorageClass()) {
545 case spirv::StorageClass::PhysicalStorageBuffer:
546 case spirv::StorageClass::PushConstant:
547 case spirv::StorageClass::StorageBuffer:
548 case spirv::StorageClass::Uniform:
549 return isa<spirv::StructType>(ptrType.getPointeeType());
550 default:
551 break;
552 }
553 }
554 return false;
555}
556
557LogicalResult Serializer::processType(Location loc, Type type,
558 uint32_t &typeID) {
559 // Maintains a set of names for nested identified struct types. This is used
560 // to properly serialize recursive references.
561 SetVector<StringRef> serializationCtx;
562 return processTypeImpl(loc, type, typeID, serializationCtx);
563}
564
565LogicalResult
566Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
567 SetVector<StringRef> &serializationCtx) {
568
569 // Map unsigned integer types to singless integer types.
570 // This is needed otherwise the generated spirv assembly will contain
571 // twice a type declaration (like OpTypeInt 32 0) which is no permitted and
572 // such module fails validation. Indeed at MLIR level the two types are
573 // different and lookup in the cache below misses.
574 // Note: This conversion needs to happen here before the type is looked up in
575 // the cache.
576 if (type.isUnsignedInteger()) {
577 type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
578 IntegerType::SignednessSemantics::Signless);
579 }
580
581 typeID = getTypeID(type);
582 if (typeID)
583 return success();
584
585 typeID = getNextID();
586 SmallVector<uint32_t, 4> operands;
587
588 operands.push_back(typeID);
589 auto typeEnum = spirv::Opcode::OpTypeVoid;
590 bool deferSerialization = false;
591
592 if ((isa<FunctionType>(type) &&
593 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
594 operands))) ||
595 (isa<GraphType>(type) &&
596 succeeded(
597 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
598 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
599 deferSerialization, serializationCtx))) {
600 if (deferSerialization)
601 return success();
602
603 typeIDMap[type] = typeID;
604
605 if (typeEnum == spirv::Opcode::OpTypeStruct)
606 encodeInstructionWithContinuationInto(typesGlobalValues, typeEnum,
607 operands);
608 else
609 encodeInstructionInto(typesGlobalValues, typeEnum, operands);
610
611 if (recursiveStructInfos.count(type) != 0) {
612 // This recursive struct type is emitted already, now the OpTypePointer
613 // instructions referring to recursive references are emitted as well.
614 for (auto &ptrInfo : recursiveStructInfos[type]) {
615 // TODO: This might not work if more than 1 recursive reference is
616 // present in the struct.
617 SmallVector<uint32_t, 4> ptrOperands;
618 ptrOperands.push_back(ptrInfo.pointerTypeID);
619 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass));
620 ptrOperands.push_back(typeIDMap[type]);
621
622 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
623 ptrOperands);
624 }
625
626 recursiveStructInfos[type].clear();
627 }
628
629 return success();
630 }
631
632 return emitError(loc, "failed to process type: ") << type;
633}
634
635LogicalResult Serializer::prepareBasicType(
636 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
637 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
638 SetVector<StringRef> &serializationCtx) {
639 deferSerialization = false;
640
641 if (isVoidType(type)) {
642 typeEnum = spirv::Opcode::OpTypeVoid;
643 return success();
644 }
645
646 if (auto intType = dyn_cast<IntegerType>(type)) {
647 if (intType.getWidth() == 1) {
648 typeEnum = spirv::Opcode::OpTypeBool;
649 return success();
650 }
651
652 typeEnum = spirv::Opcode::OpTypeInt;
653 operands.push_back(intType.getWidth());
654 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
655 // to preserve or validate.
656 // 0 indicates unsigned, or no signedness semantics
657 // 1 indicates signed semantics."
658 operands.push_back(intType.isSigned() ? 1 : 0);
659 return success();
660 }
661
662 if (auto floatType = dyn_cast<FloatType>(type)) {
663 typeEnum = spirv::Opcode::OpTypeFloat;
664 operands.push_back(floatType.getWidth());
665 if (floatType.isBF16()) {
666 operands.push_back(static_cast<uint32_t>(spirv::FPEncoding::BFloat16KHR));
667 }
668 if (floatType.isF8E4M3FN()) {
669 operands.push_back(
670 static_cast<uint32_t>(spirv::FPEncoding::Float8E4M3EXT));
671 }
672 if (floatType.isF8E5M2()) {
673 operands.push_back(
674 static_cast<uint32_t>(spirv::FPEncoding::Float8E5M2EXT));
675 }
676
677 return success();
678 }
679
680 if (auto vectorType = dyn_cast<VectorType>(type)) {
681 uint32_t elementTypeID = 0;
682 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
683 serializationCtx))) {
684 return failure();
685 }
686 typeEnum = spirv::Opcode::OpTypeVector;
687 operands.push_back(elementTypeID);
688 operands.push_back(vectorType.getNumElements());
689 return success();
690 }
691
692 if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
693 typeEnum = spirv::Opcode::OpTypeImage;
694 uint32_t sampledTypeID = 0;
695 if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
696 return failure();
697
698 llvm::append_values(operands, sampledTypeID,
699 static_cast<uint32_t>(imageType.getDim()),
700 static_cast<uint32_t>(imageType.getDepthInfo()),
701 static_cast<uint32_t>(imageType.getArrayedInfo()),
702 static_cast<uint32_t>(imageType.getSamplingInfo()),
703 static_cast<uint32_t>(imageType.getSamplerUseInfo()),
704 static_cast<uint32_t>(imageType.getImageFormat()));
705 return success();
706 }
707
708 if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
709 typeEnum = spirv::Opcode::OpTypeArray;
710 uint32_t elementTypeID = 0;
711 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
712 serializationCtx))) {
713 return failure();
714 }
715 operands.push_back(elementTypeID);
716 if (auto elementCountID = prepareConstantInt(
717 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
718 operands.push_back(elementCountID);
719 }
720 return processTypeDecoration(loc, arrayType, resultID);
721 }
722
723 if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
724 uint32_t pointeeTypeID = 0;
725 spirv::StructType pointeeStruct =
726 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
727
728 if (pointeeStruct && pointeeStruct.isIdentified() &&
729 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
730 // A recursive reference to an enclosing struct is found.
731 //
732 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage
733 // class as operands.
734 SmallVector<uint32_t, 2> forwardPtrOperands;
735 forwardPtrOperands.push_back(resultID);
736 forwardPtrOperands.push_back(
737 static_cast<uint32_t>(ptrType.getStorageClass()));
738
739 encodeInstructionInto(typesGlobalValues,
740 spirv::Opcode::OpTypeForwardPointer,
741 forwardPtrOperands);
742
743 // 2. Find the pointee (enclosing) struct.
744 auto structType = spirv::StructType::getIdentified(
745 module.getContext(), pointeeStruct.getIdentifier());
746
747 if (!structType)
748 return failure();
749
750 // 3. Mark the OpTypePointer that is supposed to be emitted by this call
751 // as deferred.
752 deferSerialization = true;
753
754 // 4. Record the info needed to emit the deferred OpTypePointer
755 // instruction when the enclosing struct is completely serialized.
756 recursiveStructInfos[structType].push_back(
757 {resultID, ptrType.getStorageClass()});
758 } else {
759 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
760 serializationCtx)))
761 return failure();
762 }
763
764 typeEnum = spirv::Opcode::OpTypePointer;
765 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
766 operands.push_back(pointeeTypeID);
767
768 // TODO: Now struct decorations are supported this code may not be
769 // necessary. However, it is left to support backwards compatibility.
770 // Ideally, Block decorations should be inserted when converting to SPIR-V.
771 if (isInterfaceStructPtrType(ptrType)) {
772 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
773 if (!structType.hasDecoration(spirv::Decoration::Block))
774 if (failed(emitDecoration(getTypeID(pointeeStruct),
775 spirv::Decoration::Block)))
776 return emitError(loc, "cannot decorate ")
777 << pointeeStruct << " with Block decoration";
778 }
779
780 return success();
781 }
782
783 if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
784 uint32_t elementTypeID = 0;
785 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
786 elementTypeID, serializationCtx))) {
787 return failure();
788 }
789 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
790 operands.push_back(elementTypeID);
791 return processTypeDecoration(loc, runtimeArrayType, resultID);
792 }
793
794 if (isa<spirv::SamplerType>(type)) {
795 typeEnum = spirv::Opcode::OpTypeSampler;
796 return success();
797 }
798
799 if (isa<spirv::NamedBarrierType>(type)) {
800 typeEnum = spirv::Opcode::OpTypeNamedBarrier;
801 return success();
802 }
803
804 if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
805 typeEnum = spirv::Opcode::OpTypeSampledImage;
806 uint32_t imageTypeID = 0;
807 if (failed(
808 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
809 return failure();
810 }
811 operands.push_back(imageTypeID);
812 return success();
813 }
814
815 if (auto structType = dyn_cast<spirv::StructType>(type)) {
816 if (structType.isIdentified()) {
817 if (failed(processName(resultID, structType.getIdentifier())))
818 return failure();
819 serializationCtx.insert(structType.getIdentifier());
820 }
821
822 bool hasOffset = structType.hasOffset();
823 for (auto elementIndex :
824 llvm::seq<uint32_t>(0, structType.getNumElements())) {
825 uint32_t elementTypeID = 0;
826 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
827 elementTypeID, serializationCtx))) {
828 return failure();
829 }
830 operands.push_back(elementTypeID);
831 if (hasOffset) {
832 auto intType = IntegerType::get(structType.getContext(), 32);
833 // Decorate each struct member with an offset
834 spirv::StructType::MemberDecorationInfo offsetDecoration{
835 elementIndex, spirv::Decoration::Offset,
836 IntegerAttr::get(intType,
837 structType.getMemberOffset(elementIndex))};
838 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
839 return emitError(loc, "cannot decorate ")
840 << elementIndex << "-th member of " << structType
841 << " with its offset";
842 }
843 }
844 }
845 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
846 structType.getMemberDecorations(memberDecorations);
847
848 for (auto &memberDecoration : memberDecorations) {
849 if (failed(processMemberDecoration(resultID, memberDecoration))) {
850 return emitError(loc, "cannot decorate ")
851 << static_cast<uint32_t>(memberDecoration.memberIndex)
852 << "-th member of " << structType << " with "
853 << stringifyDecoration(memberDecoration.decoration);
854 }
855 }
856
857 SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
858 structType.getStructDecorations(structDecorations);
859
860 for (spirv::StructType::StructDecorationInfo &structDecoration :
861 structDecorations) {
862 if (failed(processDecorationAttr(loc, resultID,
863 structDecoration.decoration,
864 structDecoration.decorationValue))) {
865 return emitError(loc, "cannot decorate struct ")
866 << structType << " with "
867 << stringifyDecoration(structDecoration.decoration);
868 }
869 }
870
871 typeEnum = spirv::Opcode::OpTypeStruct;
872
873 if (structType.isIdentified())
874 serializationCtx.remove(structType.getIdentifier());
875
876 return success();
877 }
878
879 if (auto cooperativeMatrixType =
880 dyn_cast<spirv::CooperativeMatrixType>(type)) {
881 uint32_t elementTypeID = 0;
882 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
883 elementTypeID, serializationCtx))) {
884 return failure();
885 }
886 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
887 auto getConstantOp = [&](uint32_t id) {
888 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
889 return prepareConstantInt(loc, attr);
890 };
891 llvm::append_values(
892 operands, elementTypeID,
893 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())),
894 getConstantOp(cooperativeMatrixType.getRows()),
895 getConstantOp(cooperativeMatrixType.getColumns()),
896 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse())));
897 return success();
898 }
899
900 if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
901 uint32_t elementTypeID = 0;
902 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
903 serializationCtx))) {
904 return failure();
905 }
906 typeEnum = spirv::Opcode::OpTypeMatrix;
907 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
908 return success();
909 }
910
911 if (auto tensorArmType = dyn_cast<TensorArmType>(type)) {
912 uint32_t elementTypeID = 0;
913 uint32_t rank = 0;
914 uint32_t shapeID = 0;
915 uint32_t rankID = 0;
916 if (failed(processTypeImpl(loc, tensorArmType.getElementType(),
917 elementTypeID, serializationCtx))) {
918 return failure();
919 }
920 if (tensorArmType.hasRank()) {
921 ArrayRef<int64_t> dims = tensorArmType.getShape();
922 rank = dims.size();
923 rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
924 if (rankID == 0) {
925 return failure();
926 }
927
928 bool shaped = llvm::all_of(dims, [](const auto &dim) { return dim > 0; });
929 if (rank > 0 && shaped) {
930 auto I32Type = IntegerType::get(type.getContext(), 32);
931 auto shapeType = ArrayType::get(I32Type, rank);
932 if (rank == 1) {
933 SmallVector<uint64_t, 1> index(rank);
934 shapeID = prepareDenseElementsConstant(
935 loc, shapeType,
936 mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
937 index);
938 } else {
939 shapeID = prepareArrayConstant(
940 loc, shapeType,
941 mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
942 }
943 if (shapeID == 0) {
944 return failure();
945 }
946 }
947 }
948 typeEnum = spirv::Opcode::OpTypeTensorARM;
949 operands.push_back(elementTypeID);
950 if (rankID == 0)
951 return success();
952 operands.push_back(rankID);
953 if (shapeID == 0)
954 return success();
955 operands.push_back(shapeID);
956 return success();
957 }
958
959 // TODO: Handle other types.
960 return emitError(loc, "unhandled type in serialization: ") << type;
961}
962
963LogicalResult
964Serializer::prepareFunctionType(Location loc, FunctionType type,
965 spirv::Opcode &typeEnum,
966 SmallVectorImpl<uint32_t> &operands) {
967 typeEnum = spirv::Opcode::OpTypeFunction;
968 assert(type.getNumResults() <= 1 &&
969 "serialization supports only a single return value");
970 uint32_t resultID = 0;
971 if (failed(processType(
972 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
973 resultID))) {
974 return failure();
975 }
976 operands.push_back(resultID);
977 for (auto &res : type.getInputs()) {
978 uint32_t argTypeID = 0;
979 if (failed(processType(loc, res, argTypeID))) {
980 return failure();
981 }
982 operands.push_back(argTypeID);
983 }
984 return success();
985}
986
987LogicalResult
988Serializer::prepareGraphType(Location loc, GraphType type,
989 spirv::Opcode &typeEnum,
990 SmallVectorImpl<uint32_t> &operands) {
991 typeEnum = spirv::Opcode::OpTypeGraphARM;
992 assert(type.getNumResults() >= 1 &&
993 "serialization requires at least a return value");
994
995 operands.push_back(type.getNumInputs());
996
997 for (Type argType : type.getInputs()) {
998 uint32_t argTypeID = 0;
999 if (failed(processType(loc, argType, argTypeID)))
1000 return failure();
1001 operands.push_back(argTypeID);
1002 }
1003
1004 for (Type resType : type.getResults()) {
1005 uint32_t resTypeID = 0;
1006 if (failed(processType(loc, resType, resTypeID)))
1007 return failure();
1008 operands.push_back(resTypeID);
1009 }
1010
1011 return success();
1012}
1013
1014//===----------------------------------------------------------------------===//
1015// Constant
1016//===----------------------------------------------------------------------===//
1017
1018uint32_t Serializer::prepareConstant(Location loc, Type constType,
1019 Attribute valueAttr) {
1020 if (auto id = prepareConstantScalar(loc, valueAttr)) {
1021 return id;
1022 }
1023
1024 // This is a composite literal. We need to handle each component separately
1025 // and then emit an OpConstantComposite for the whole.
1026
1027 if (auto id = getConstantID(valueAttr)) {
1028 return id;
1029 }
1030
1031 uint32_t typeID = 0;
1032 if (failed(processType(loc, constType, typeID))) {
1033 return 0;
1034 }
1035
1036 uint32_t resultID = 0;
1037 if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
1038 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
1039 SmallVector<uint64_t, 4> index(rank);
1040 resultID = prepareDenseElementsConstant(loc, constType, attr,
1041 /*dim=*/0, index);
1042 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1043 resultID = prepareArrayConstant(loc, constType, arrayAttr);
1044 }
1045
1046 if (resultID == 0) {
1047 emitError(loc, "cannot serialize attribute: ") << valueAttr;
1048 return 0;
1049 }
1050
1051 constIDMap[valueAttr] = resultID;
1052 return resultID;
1053}
1054
1055uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1056 ArrayAttr attr) {
1057 uint32_t typeID = 0;
1058 if (failed(processType(loc, constType, typeID))) {
1059 return 0;
1060 }
1061
1062 uint32_t resultID = getNextID();
1063 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1064 operands.reserve(attr.size() + 2);
1065 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
1066 for (Attribute elementAttr : attr) {
1067 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
1068 operands.push_back(elementID);
1069 } else {
1070 return 0;
1071 }
1072 }
1073 encodeInstructionWithContinuationInto(
1074 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1075
1076 return resultID;
1077}
1078
// TODO: Turn the below function into iterative function, instead of
// recursive function.
/// Recursively serializes the dense elements constant `valueAttr` as a value
/// of type `constType`.
///
/// `dim` is the dimension currently being expanded. `index` holds the
/// multi-dimensional element index built so far, with one entry per dimension
/// of `valueAttr`'s shaped type. When `dim` equals the rank, the single
/// element at `index` is emitted as a scalar constant. Otherwise one
/// constituent is produced per entry along `dim` and combined into an
/// OpConstantComposite, with two special cases:
///   - a cooperative matrix type takes a single splat constituent, and
///   - an all-zero value of TensorARM type is emitted as OpConstantNull.
/// Returns the result <id>, or 0 on failure.
uint32_t
Serializer::prepareDenseElementsConstant(Location loc, Type constType,
                                         DenseElementsAttr valueAttr, int dim,
                                         MutableArrayRef<uint64_t> index) {
  auto shapedType = dyn_cast<ShapedType>(valueAttr.getType());
  assert(dim <= shapedType.getRank());
  // Base case: every dimension is fixed, so `index` names a single element.
  if (shapedType.getRank() == dim) {
    // i1 elements are serialized as booleans, other integers as integers.
    if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
      return attr.getType().getElementType().isInteger(1)
                 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
                 : prepareConstantInt(loc,
                                      attr.getValues<IntegerAttr>()[index]);
    }
    if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
      return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
    }
    return 0;
  }

  uint32_t typeID = 0;
  if (failed(processType(loc, constType, typeID))) {
    return 0;
  }

  int64_t numberOfConstituents = shapedType.getDimSize(dim);
  // The composite's <id> is allocated before any constituent is prepared.
  uint32_t resultID = getNextID();
  SmallVector<uint32_t, 4> operands = {typeID, resultID};
  // Constituents have the composite's element type. For TensorARM, each
  // constituent is itself a tensor of the remaining (inner) shape, unless
  // only one dimension is left.
  auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
  if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
    ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
    if (!innerShape.empty())
      elementType = spirv::TensorArmType::get(innerShape, elementType);
  }

  // "If the Result Type is a cooperative matrix type, then there must be only
  // one Constituent, with scalar type matching the cooperative matrix Component
  // Type, and all components of the matrix are initialized to that value."
  // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
  if (isa<spirv::CooperativeMatrixType>(constType)) {
    if (!valueAttr.isSplat()) {
      emitError(
          loc,
          "cannot serialize a non-splat value for a cooperative matrix type");
      return 0;
    }
    // numberOfConstituents is 1, so we only need one more element in the
    // SmallVector, so the total is 3 (1 + 2).
    operands.reserve(3);
    // We set dim directly to `shapedType.getRank()` so the recursive call
    // directly returns the scalar type.
    if (auto elementID = prepareDenseElementsConstant(
            loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
      operands.push_back(elementID);
    } else {
      return 0;
    }
  } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
    // An all-zero tensor is emitted as OpConstantNull, reusing the result
    // <id> allocated above; no constituents are serialized.
    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
                          {typeID, resultID});
    return resultID;
  } else {
    operands.reserve(numberOfConstituents + 2);
    // Fix `index[dim]` to each position in turn and recurse into the next
    // dimension to build the corresponding constituent.
    for (int i = 0; i < numberOfConstituents; ++i) {
      index[dim] = i;
      if (auto elementID = prepareDenseElementsConstant(
              loc, elementType, valueAttr, dim + 1, index)) {
        operands.push_back(elementID);
      } else {
        return 0;
      }
    }
  }
  encodeInstructionWithContinuationInto(
      typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);

  return resultID;
}
1158
1159uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1160 bool isSpec) {
1161 if (auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1162 return prepareConstantFp(loc, floatAttr, isSpec);
1163 }
1164 if (auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1165 return prepareConstantBool(loc, boolAttr, isSpec);
1166 }
1167 if (auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1168 return prepareConstantInt(loc, intAttr, isSpec);
1169 }
1170
1171 return 0;
1172}
1173
1174uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1175 bool isSpec) {
1176 if (!isSpec) {
1177 // We can de-duplicate normal constants, but not specialization constants.
1178 if (auto id = getConstantID(boolAttr)) {
1179 return id;
1180 }
1181 }
1182
1183 // Process the type for this bool literal
1184 uint32_t typeID = 0;
1185 if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
1186 return 0;
1187 }
1188
1189 auto resultID = getNextID();
1190 auto opcode = boolAttr.getValue()
1191 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1192 : spirv::Opcode::OpConstantTrue)
1193 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1194 : spirv::Opcode::OpConstantFalse);
1195 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});
1196
1197 if (!isSpec) {
1198 constIDMap[boolAttr] = resultID;
1199 }
1200 return resultID;
1201}
1202
1203uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1204 bool isSpec) {
1205 if (!isSpec) {
1206 // We can de-duplicate normal constants, but not specialization constants.
1207 if (auto id = getConstantID(intAttr)) {
1208 return id;
1209 }
1210 }
1211
1212 // Process the type for this integer literal
1213 uint32_t typeID = 0;
1214 if (failed(processType(loc, intAttr.getType(), typeID))) {
1215 return 0;
1216 }
1217
1218 auto resultID = getNextID();
1219 APInt value = intAttr.getValue();
1220 unsigned bitwidth = value.getBitWidth();
1221 bool isSigned = intAttr.getType().isSignedInteger();
1222 auto opcode =
1223 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1224
1225 switch (bitwidth) {
1226 // According to SPIR-V spec, "When the type's bit width is less than
1227 // 32-bits, the literal's value appears in the low-order bits of the word,
1228 // and the high-order bits must be 0 for a floating-point type, or 0 for an
1229 // integer type with Signedness of 0, or sign extended when Signedness
1230 // is 1."
1231 case 32:
1232 case 16:
1233 case 8: {
1234 uint32_t word = 0;
1235 if (isSigned) {
1236 word = static_cast<int32_t>(value.getSExtValue());
1237 } else {
1238 word = static_cast<uint32_t>(value.getZExtValue());
1239 }
1240 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1241 } break;
1242 // According to SPIR-V spec: "When the type's bit width is larger than one
1243 // word, the literal’s low-order words appear first."
1244 case 64: {
1245 struct DoubleWord {
1246 uint32_t word1;
1247 uint32_t word2;
1248 } words;
1249 if (isSigned) {
1250 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1251 } else {
1252 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1253 }
1254 encodeInstructionInto(typesGlobalValues, opcode,
1255 {typeID, resultID, words.word1, words.word2});
1256 } break;
1257 default: {
1258 std::string valueStr;
1259 llvm::raw_string_ostream rss(valueStr);
1260 value.print(rss, /*isSigned=*/false);
1261
1262 emitError(loc, "cannot serialize ")
1263 << bitwidth << "-bit integer literal: " << valueStr;
1264 return 0;
1265 }
1266 }
1267
1268 if (!isSpec) {
1269 constIDMap[intAttr] = resultID;
1270 }
1271 return resultID;
1272}
1273
1274uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
1275 IntegerAttr intAttr) {
1276 // De-duplicate graph constants.
1277 if (uint32_t id = getGraphConstantARMId(intAttr)) {
1278 return id;
1279 }
1280
1281 // Process the type for this graph constant.
1282 uint32_t typeID = 0;
1283 if (failed(processType(loc, graphConstType, typeID))) {
1284 return 0;
1285 }
1286
1287 uint32_t resultID = getNextID();
1288 APInt value = intAttr.getValue();
1289 unsigned bitwidth = value.getBitWidth();
1290 if (bitwidth > 32) {
1291 emitError(loc, "Too wide attribute for OpGraphConstantARM: ")
1292 << bitwidth << " bits";
1293 return 0;
1294 }
1295 bool isSigned = value.isSignedIntN(bitwidth);
1296
1297 uint32_t word = 0;
1298 if (isSigned) {
1299 word = static_cast<int32_t>(value.getSExtValue());
1300 } else {
1301 word = static_cast<uint32_t>(value.getZExtValue());
1302 }
1303 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpGraphConstantARM,
1304 {typeID, resultID, word});
1305 graphConstIDMap[intAttr] = resultID;
1306 return resultID;
1307}
1308
1309uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1310 bool isSpec) {
1311 if (!isSpec) {
1312 // We can de-duplicate normal constants, but not specialization constants.
1313 if (auto id = getConstantID(floatAttr)) {
1314 return id;
1315 }
1316 }
1317
1318 // Process the type for this float literal
1319 uint32_t typeID = 0;
1320 if (failed(processType(loc, floatAttr.getType(), typeID))) {
1321 return 0;
1322 }
1323
1324 auto resultID = getNextID();
1325 APFloat value = floatAttr.getValue();
1326 const llvm::fltSemantics *semantics = &value.getSemantics();
1327
1328 auto opcode =
1329 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1330
1331 if (semantics == &APFloat::IEEEsingle()) {
1332 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1333 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1334 } else if (semantics == &APFloat::IEEEdouble()) {
1335 struct DoubleWord {
1336 uint32_t word1;
1337 uint32_t word2;
1338 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1339 encodeInstructionInto(typesGlobalValues, opcode,
1340 {typeID, resultID, words.word1, words.word2});
1341 } else if (llvm::is_contained({&APFloat::IEEEhalf(), &APFloat::BFloat(),
1342 &APFloat::Float8E4M3FN(),
1343 &APFloat::Float8E5M2()},
1344 semantics)) {
1345 uint32_t word =
1346 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
1347 encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
1348 } else {
1349 std::string valueStr;
1350 llvm::raw_string_ostream rss(valueStr);
1351 value.print(rss);
1352
1353 emitError(loc, "cannot serialize ")
1354 << floatAttr.getType() << "-typed float literal: " << valueStr;
1355 return 0;
1356 }
1357
1358 if (!isSpec) {
1359 constIDMap[floatAttr] = resultID;
1360 }
1361 return resultID;
1362}
1363
// Returns the type of `attr`. For a TypedAttr this is simply its type. An
// ArrayAttr is untyped and may be nested (multidimensional), so a matching
// spirv::ArrayType is built recursively from its first element. Any other
// attribute yields a null Type.
// NOTE(review): `arrayAttr[0]` assumes a non-empty ArrayAttr; an empty one
// would be indexed out of bounds -- confirm callers never pass one.
  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
    return typedAttr.getType();
  }

  // Only the first element is inspected; presumably all elements of the
  // array share its type -- TODO confirm this is verified elsewhere.
  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
    return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
  }

  return nullptr;
}
1378
1379uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
1380 Type resultType,
1381 Attribute valueAttr) {
1382 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1383 if (uint32_t id = getConstantCompositeReplicateID(valueTypePair)) {
1384 return id;
1385 }
1386
1387 uint32_t typeID = 0;
1388 if (failed(processType(loc, resultType, typeID))) {
1389 return 0;
1390 }
1391
1392 Type valueType = getValueType(valueAttr);
1393 if (!valueAttr)
1394 return 0;
1395
1396 auto compositeType = dyn_cast<CompositeType>(resultType);
1397 if (!compositeType)
1398 return 0;
1399 Type elementType = compositeType.getElementType(0);
1400
1401 uint32_t constandID;
1402 if (elementType == valueType) {
1403 constandID = prepareConstant(loc, elementType, valueAttr);
1404 } else {
1405 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1406 }
1407
1408 uint32_t resultID = getNextID();
1409 if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
1410 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
1411 {typeID, resultID});
1412 } else {
1413 encodeInstructionInto(typesGlobalValues,
1414 spirv::Opcode::OpConstantCompositeReplicateEXT,
1415 {typeID, resultID, constandID});
1416 }
1417
1418 constCompositeReplicateIDMap[valueTypePair] = resultID;
1419 return resultID;
1420}
1421
1422//===----------------------------------------------------------------------===//
1423// Control flow
1424//===----------------------------------------------------------------------===//
1425
1426uint32_t Serializer::getOrCreateBlockID(Block *block) {
1427 if (uint32_t id = getBlockID(block))
1428 return id;
1429 return blockIDMap[block] = getNextID();
1430}
1431
1432#ifndef NDEBUG
1433void Serializer::printBlock(Block *block, raw_ostream &os) {
1434 os << "block " << block << " (id = ";
1435 if (uint32_t id = getBlockID(block))
1436 os << id;
1437 else
1438 os << "unknown";
1439 os << ")\n";
1440}
1441#endif
1442
/// Serializes `block` into the function body.
///
/// Unless `omitLabel` is set, an OpLabel carrying the block's <id> is emitted
/// first, followed by OpPhi instructions for the block arguments. If
/// `emitMerge` is provided, it is invoked once:
///   - right after the phis, when the block contains a spirv.mlir.loop or
///     spirv.mlir.selection op (a fresh SPIR-V block is then started for the
///     remaining ops);
///   - otherwise, just before the terminator is serialized.
LogicalResult
Serializer::processBlock(Block *block, bool omitLabel,
                         function_ref<LogicalResult()> emitMerge) {
  LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
  LLVM_DEBUG(block->print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');
  if (!omitLabel) {
    uint32_t blockID = getOrCreateBlockID(block);
    LLVM_DEBUG(printBlock(block, llvm::dbgs()));

    // Emit OpLabel for this block.
    encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
  }

  // Emit OpPhi instructions for block arguments, if any.
  if (failed(emitPhiForBlockArguments(block)))
    return failure();

  // If we need to emit merge instructions, it must happen in this block. Check
  // whether we have other structured control flow ops, which will be expanded
  // into multiple basic blocks. If that's the case, we need to emit the merge
  // right now and then create new blocks for further serialization of the ops
  // in this block.
  if (emitMerge &&
      llvm::any_of(block->getOperations(),
                   llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
    if (failed(emitMerge()))
      return failure();
    // Clear the callback so it is not invoked again before the terminator.
    emitMerge = nullptr;

    // Start a new block for further serialization.
    // This <id> is not recorded for any MLIR block; it only names the SPIR-V
    // block that the unconditional branch below jumps to.
    uint32_t blockID = getNextID();
    encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
    encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
  }

  // Process each op in this block except the terminator.
  for (Operation &op : llvm::drop_end(*block)) {
    if (failed(processOperation(&op)))
      return failure();
  }

  // Process the terminator.
  // A merge instruction that was not emitted early goes right before it.
  if (emitMerge)
    if (failed(emitMerge()))
      return failure();
  if (failed(processOperation(&block->back())))
    return failure();

  return success();
}
1494
/// Emits one OpPhi per argument of `block`, pairing each incoming value with
/// the SPIR-V block it comes from.
///
/// Nothing is emitted for blocks without arguments or for entry blocks.
/// Supported predecessor terminators are spirv.Branch,
/// spirv.BranchConditional and spirv.Switch; any other terminator is an
/// error. Incoming values that have no <id> yet are encoded as 0, and their
/// word positions in functionBody are recorded in deferredPhiValues so they
/// can be fixed up later.
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
  // Nothing to do if this block has no arguments or it's the entry block, which
  // always has the same arguments as the function signature.
  if (block->args_empty() || block->isEntryBlock())
    return success();

  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");

  // If the block has arguments, we need to create SPIR-V OpPhi instructions.
  // A SPIR-V OpPhi instruction is of the syntax:
  //   OpPhi | result type | result <id> | (value <id>, parent block <id>) pair
  // So we need to collect all predecessor blocks and the arguments they send
  // to this block.
  SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
  for (Block *mlirPredecessor : block->getPredecessors()) {
    auto *terminator = mlirPredecessor->getTerminator();
    LLVM_DEBUG(llvm::dbgs() << "  mlir predecessor ");
    LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "    terminator: " << *terminator << "\n");
    // The predecessor here is the immediate one according to MLIR's IR
    // structure. It does not directly map to the incoming parent block for the
    // OpPhi instructions at SPIR-V binary level. This is because structured
    // control flow ops are serialized to multiple SPIR-V blocks. If there is a
    // spirv.mlir.selection/spirv.mlir.loop op in the MLIR predecessor block,
    // the branch op jumping to the OpPhi's block then resides in the previous
    // structured control flow op's merge block.
    Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
    LLVM_DEBUG(llvm::dbgs() << "  spirv predecessor ");
    LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
    if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
      predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
    } else if (auto branchCondOp =
                   dyn_cast<spirv::BranchConditionalOp>(terminator)) {
      // NOTE(review): if both targets are `block`, the true-target operands
      // are chosen for every visit of this predecessor -- confirm this case
      // cannot reach here.
      std::optional<OperandRange> blockOperands;
      if (branchCondOp.getTrueTarget() == block) {
        blockOperands = branchCondOp.getTrueTargetOperands();
      } else {
        assert(branchCondOp.getFalseTarget() == block);
        blockOperands = branchCondOp.getFalseTargetOperands();
      }
      assert(!blockOperands->empty() &&
             "expected non-empty block operand range");
      predecessors.emplace_back(spirvPredecessor, *blockOperands);
    } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
      // Pick the operands forwarded along the edge that targets `block`: the
      // default edge, or the first case whose target is `block`.
      std::optional<OperandRange> blockOperands;
      if (block == switchOp.getDefaultTarget()) {
        blockOperands = switchOp.getDefaultOperands();
      } else {
        SuccessorRange targets = switchOp.getTargets();
        auto it = llvm::find(targets, block);
        assert(it != targets.end());
        size_t index = std::distance(targets.begin(), it);
        blockOperands = switchOp.getTargetOperands(index);
      }
      assert(!blockOperands->empty() &&
             "expected non-empty block operand range");
      predecessors.emplace_back(spirvPredecessor, *blockOperands);
    } else {
      return terminator->emitError("unimplemented terminator for Phi creation");
    }
    LLVM_DEBUG({
      llvm::dbgs() << "    block arguments:\n";
      for (Value v : predecessors.back().second)
        llvm::dbgs() << "      " << v << "\n";
    });
  }

  // Then create OpPhi instruction for each of the block argument.
  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
    BlockArgument arg = block->getArgument(argIndex);

    // Get the type <id> and result <id> for this OpPhi instruction.
    uint32_t phiTypeID = 0;
    if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
      return failure();
    uint32_t phiID = getNextID();

    LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                            << arg << " (id = " << phiID << ")\n");

    // Prepare the (value <id>, parent block <id>) pairs.
    SmallVector<uint32_t, 8> phiArgs;
    phiArgs.push_back(phiTypeID);
    phiArgs.push_back(phiID);

    for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
      Value value = predecessors[predIndex].second[argIndex];
      // May assign an <id> to a predecessor block not yet serialized.
      uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
      LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                              << ") value " << value << ' ');
      // Each pair is a value <id> ...
      uint32_t valueId = getValueID(value);
      if (valueId == 0) {
        // The op generating this value hasn't been visited yet so we don't have
        // an <id> assigned yet. Record this to fix up later.
        // The recorded offset is where this value word will land in
        // functionBody once the OpPhi below is encoded: the current size, plus
        // one word (presumably the opcode/word-count word that
        // encodeInstructionInto writes first), plus the operands so far.
        LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
        deferredPhiValues[value].push_back(functionBody.size() + 1 +
                                           phiArgs.size());
      } else {
        LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
      }
      phiArgs.push_back(valueId);
      // ... and a parent block <id>.
      phiArgs.push_back(predBlockId);
    }

    encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
    valueIDMap[arg] = phiID;
  }

  return success();
}
1607
1608//===----------------------------------------------------------------------===//
1609// Operation
1610//===----------------------------------------------------------------------===//
1611
1612LogicalResult Serializer::encodeExtensionInstruction(
1613 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1614 ArrayRef<uint32_t> operands) {
1615 // Check if the extension has been imported.
1616 auto &setID = extendedInstSetIDMap[extensionSetName];
1617 if (!setID) {
1618 setID = getNextID();
1619 SmallVector<uint32_t, 16> importOperands;
1620 importOperands.push_back(setID);
1621 spirv::encodeStringLiteralInto(importOperands, extensionSetName);
1622 encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
1623 importOperands);
1624 }
1625
1626 // The first two operands are the result type <id> and result <id>. The set
1627 // <id> and the opcode need to be insert after this.
1628 if (operands.size() < 2) {
1629 return op->emitError("extended instructions must have a result encoding");
1630 }
1631 SmallVector<uint32_t, 8> extInstOperands;
1632 extInstOperands.reserve(operands.size() + 2);
1633 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1634 extInstOperands.push_back(setID);
1635 extInstOperands.push_back(extensionOpcode);
1636 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1637 encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
1638 extInstOperands);
1639 return success();
1640}
1641
/// Serializes a single op. Ops that need hand-written handling (control flow,
/// constants, functions and graphs, globals, variables, ...) are dispatched
/// to dedicated process* methods. Every other op falls through to the
/// auto-generated dispatchToAutogenSerialization.
LogicalResult Serializer::processOperation(Operation *opInst) {
  LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");

  // First dispatch the ops that do not directly mirror an instruction from
  // the SPIR-V spec.
      .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); })
      .Case([&](spirv::BranchOp op) { return processBranchOp(op); })
      .Case([&](spirv::BranchConditionalOp op) {
        return processBranchConditionalOp(op);
      })
      .Case([&](spirv::ConstantOp op) { return processConstantOp(op); })
      .Case([&](spirv::CompositeConstructOp op) {
        return processCompositeConstructOp(op);
      })
      .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
        return processConstantCompositeReplicateOp(op);
      })
      .Case([&](spirv::FuncOp op) { return processFuncOp(op); })
      .Case([&](spirv::GraphARMOp op) { return processGraphARMOp(op); })
      .Case([&](spirv::GraphEntryPointARMOp op) {
        return processGraphEntryPointARMOp(op);
      })
      .Case([&](spirv::GraphOutputsARMOp op) {
        return processGraphOutputsARMOp(op);
      })
      .Case([&](spirv::GlobalVariableOp op) {
        return processGlobalVariableOp(op);
      })
      .Case([&](spirv::GraphConstantARMOp op) {
        return processGraphConstantARMOp(op);
      })
      .Case([&](spirv::LoopOp op) { return processLoopOp(op); })
      .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
      .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
      .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
      .Case([&](spirv::SpecConstantCompositeOp op) {
        return processSpecConstantCompositeOp(op);
      })
      .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
        return processSpecConstantCompositeReplicateOp(op);
      })
      .Case([&](spirv::SpecConstantOperationOp op) {
        return processSpecConstantOperationOp(op);
      })
      .Case([&](spirv::SwitchOp op) { return processSwitchOp(op); })
      .Case([&](spirv::UndefOp op) { return processUndefOp(op); })
      .Case([&](spirv::VariableOp op) { return processVariableOp(op); })

      // Then handle all the ops that directly mirror SPIR-V instructions with
      // auto-generated methods.
      .Default(
          [&](Operation *op) { return dispatchToAutogenSerialization(op); });
}
1696
1697LogicalResult
1698Serializer::processCompositeConstructOp(spirv::CompositeConstructOp op) {
1699 Location loc = op.getLoc();
1700
1701 uint32_t resultTypeID = 0;
1702 if (failed(processType(loc, op.getType(), resultTypeID)))
1703 return failure();
1704
1705 uint32_t resultID = getNextID();
1706 valueIDMap[op.getResult()] = resultID;
1707
1708 SmallVector<uint32_t, 8> operands;
1709 operands.reserve(2 + op.getConstituents().size());
1710 operands.push_back(resultTypeID);
1711 operands.push_back(resultID);
1712 for (Value constituent : op.getConstituents()) {
1713 uint32_t id = getValueID(constituent);
1714 assert(id && "use before def!");
1715 operands.push_back(id);
1716 }
1717
1718 if (failed(emitDebugLine(functionBody, loc)))
1719 return failure();
1720
1721 encodeInstructionWithContinuationInto(
1722 functionBody, spirv::Opcode::OpCompositeConstruct, operands);
1723
1724 for (auto attr : op->getAttrs()) {
1725 if (failed(processDecoration(loc, resultID, attr)))
1726 return failure();
1727 }
1728
1729 return success();
1730}
1731
1732LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1733 StringRef extInstSet,
1734 uint32_t opcode) {
1735 SmallVector<uint32_t, 4> operands;
1736 Location loc = op->getLoc();
1737
1738 uint32_t resultID = 0;
1739 if (op->getNumResults() != 0) {
1740 uint32_t resultTypeID = 0;
1741 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID)))
1742 return failure();
1743 operands.push_back(resultTypeID);
1744
1745 resultID = getNextID();
1746 operands.push_back(resultID);
1747 valueIDMap[op->getResult(0)] = resultID;
1748 };
1749
1750 for (Value operand : op->getOperands())
1751 operands.push_back(getValueID(operand));
1752
1753 if (failed(emitDebugLine(functionBody, loc)))
1754 return failure();
1755
1756 if (extInstSet.empty()) {
1757 encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
1758 operands);
1759 } else {
1760 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1761 return failure();
1762 }
1763
1764 if (op->getNumResults() != 0) {
1765 for (auto attr : op->getAttrs()) {
1766 if (failed(processDecoration(loc, resultID, attr)))
1767 return failure();
1768 }
1769 }
1770
1771 return success();
1772}
1773
1774LogicalResult Serializer::emitDecoration(uint32_t target,
1775 spirv::Decoration decoration,
1776 ArrayRef<uint32_t> params) {
1777 uint32_t wordCount = 3 + params.size();
1778 llvm::append_values(
1779 decorations,
1780 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate), target,
1781 static_cast<uint32_t>(decoration));
1782 llvm::append_range(decorations, params);
1783 return success();
1784}
1785
1786LogicalResult Serializer::emitDecorationId(uint32_t target,
1787 spirv::Decoration decoration,
1788 ArrayRef<uint32_t> operandIds) {
1789 uint32_t wordCount = 3 + operandIds.size();
1790 llvm::append_values(
1791 decorations,
1792 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorateId), target,
1793 static_cast<uint32_t>(decoration));
1794 llvm::append_range(decorations, operandIds);
1795 return success();
1796}
1797
1798LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1799 Location loc) {
1800 if (!options.emitDebugInfo)
1801 return success();
1802
1803 if (lastProcessedWasMergeInst) {
1804 lastProcessedWasMergeInst = false;
1805 return success();
1806 }
1807
1808 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1809 if (fileLoc)
1810 encodeInstructionInto(binary, spirv::Opcode::OpLine,
1811 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
1812 return success();
1813}
1814} // namespace spirv
1815} // namespace mlir
return success()
ArrayAttr()
b getContext())
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
static bool isZeroValue(Attribute attr)
static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp)
Move all functions declaration before functions definitions.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Location getLoc() const
Return the location for this argument.
Definition Value.h:321
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< pred_iterator > getPredecessors()
Definition Block.h:250
OpListType & getOperations()
Definition Block.h:147
Operation & back()
Definition Block.h:162
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:109
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
llvm::iplist< Operation > OpListType
This is the list of operations in the block.
Definition Block.h:146
bool getValue() const
Return the boolean value of this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:537
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:115
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:429
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::optional< spirv::Opcode > getContinuationOpcode(spirv::Opcode parent)
Returns the SPV_INTEL_long_composites continuation opcode that may follow parent, or std::nullopt if ...
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
constexpr uint32_t kMaxWordCount
Max number of words https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_universal_limits.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147