21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/StringExtras.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/ADT/bit.h"
26#include "llvm/Support/Debug.h"
30#define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
74 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
77 if (
auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
80 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
83 if (
auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
86 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
102 for (
Operation &op : llvm::drop_begin(ops))
103 if (
auto funcOp = dyn_cast<spirv::FuncOp>(op))
104 if (funcOp.getBody().empty())
115 uint32_t wordCount = 1 + operands.size();
117 binary.append(operands.begin(), operands.end());
122 : module(module), mlirBuilder(module.
getContext()), options(options) {}
125 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
127 if (failed(module.verifyInvariants()))
132 if (failed(processExtension())) {
135 processMemoryModel();
142 for (
auto &op : *module.getBody()) {
143 if (failed(processOperation(&op))) {
148 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
154 extensions.size() + extendedSets.size() +
155 memoryModel.size() + entryPoints.size() +
156 executionModes.size() + decorations.size() +
157 typesGlobalValues.size() + functions.size() + graphs.size();
160 binary.reserve(moduleSize);
164 binary.append(capabilities.begin(), capabilities.end());
165 binary.append(extensions.begin(), extensions.end());
166 binary.append(extendedSets.begin(), extendedSets.end());
167 binary.append(memoryModel.begin(), memoryModel.end());
168 binary.append(entryPoints.begin(), entryPoints.end());
169 binary.append(executionModes.begin(), executionModes.end());
170 binary.append(debug.begin(), debug.end());
171 binary.append(names.begin(), names.end());
172 binary.append(decorations.begin(), decorations.end());
173 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
174 binary.append(functions.begin(), functions.end());
175 binary.append(graphs.begin(), graphs.end());
176 binary.append(graphsDebugInfo.begin(), graphsDebugInfo.end());
181 os <<
"\n= Value <id> Map =\n\n";
182 for (
auto valueIDPair : valueIDMap) {
183 Value val = valueIDPair.first;
184 os <<
" " << val <<
" "
185 <<
"id = " << valueIDPair.second <<
' ';
187 os <<
"from op '" << op->getName() <<
"'";
188 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
189 Block *block = arg.getOwner();
190 os <<
"from argument of block " << block <<
' ';
202uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
203 auto funcID = funcIDMap.lookup(fnName);
205 funcID = getNextID();
206 funcIDMap[fnName] = funcID;
211void Serializer::processCapability() {
212 for (
auto cap : module.getVceTriple()->getCapabilities())
214 {
static_cast<uint32_t
>(cap)});
217void Serializer::addLongCompositesCapability() {
218 if (longCompositesEmitted)
220 longCompositesEmitted =
true;
221 auto vceTriple =
module.getVceTriple();
222 if (!llvm::is_contained(vceTriple->getCapabilities(),
223 spirv::Capability::LongCompositesINTEL))
225 capabilities, spirv::Opcode::OpCapability,
226 {
static_cast<uint32_t
>(spirv::Capability::LongCompositesINTEL)});
227 if (!llvm::is_contained(vceTriple->getExtensions(),
228 spirv::Extension::SPV_INTEL_long_composites)) {
229 SmallVector<uint32_t, 8> extName;
232 spirv::stringifyExtension(spirv::Extension::SPV_INTEL_long_composites));
237void Serializer::encodeInstructionWithContinuationInto(
238 SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
239 ArrayRef<uint32_t> operands) {
245 std::optional<spirv::Opcode> continuationOp =
247 assert(continuationOp &&
"op is not a splittable composite/struct opcode");
251 for (ArrayRef<uint32_t> rest = operands.drop_front(chunk); !rest.empty();
252 rest = rest.drop_front(std::min<size_t>(rest.size(), chunk))) {
256 addLongCompositesCapability();
259void Serializer::processDebugInfo() {
260 if (!options.emitDebugInfo)
262 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
263 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
264 fileID = getNextID();
265 SmallVector<uint32_t, 16> operands;
266 operands.push_back(fileID);
272LogicalResult Serializer::processExtension() {
273 llvm::SmallVector<uint32_t, 16> extName;
274 llvm::SmallSet<Extension, 4> deducedExts(
275 llvm::from_range, module.getVceTriple()->getExtensions());
276 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
277 if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
279 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
280 return module.emitError(
281 "SPV_KHR_non_semantic_info extension not available");
282 deducedExts.insert(nonSemanticInfoExt);
284 for (spirv::Extension ext : deducedExts) {
292void Serializer::processMemoryModel() {
293 StringAttr memoryModelName =
module.getMemoryModelAttrName();
294 auto mm =
static_cast<uint32_t
>(
295 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
298 StringAttr addressingModelName =
module.getAddressingModelAttrName();
299 auto am =
static_cast<uint32_t
>(
300 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
309 if (attrName ==
"fp_fast_math_mode")
310 return "FPFastMathMode";
312 if (attrName ==
"fp_rounding_mode")
313 return "FPRoundingMode";
315 if (attrName ==
"cache_control_load_intel")
316 return "CacheControlLoadINTEL";
317 if (attrName ==
"cache_control_store_intel")
318 return "CacheControlStoreINTEL";
320 return llvm::convertToCamelFromSnakeCase(attrName,
true);
323template <
typename AttrTy,
typename EmitF>
326 StringRef attrName, EmitF emitter) {
327 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
329 return emitError(loc,
"expecting array attribute of ")
330 << attrName <<
" for " << stringifyDecoration(decoration);
332 if (arrayAttr.empty()) {
333 return emitError(loc,
"expecting non-empty array attribute of ")
334 << attrName <<
" for " << stringifyDecoration(decoration);
336 for (
Attribute attr : arrayAttr.getValue()) {
337 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
338 if (!cacheControlAttr) {
339 return emitError(loc,
"expecting array attribute of ")
340 << attrName <<
" for " << stringifyDecoration(decoration);
344 if (failed(emitter(cacheControlAttr)))
350LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
351 Decoration decoration,
354 switch (decoration) {
355 case spirv::Decoration::LinkageAttributes: {
358 auto linkageAttr = dyn_cast<spirv::LinkageAttributesAttr>(attr);
359 auto linkageName = linkageAttr.getLinkageName();
360 auto linkageType = linkageAttr.getLinkageType().getValue();
364 args.push_back(
static_cast<uint32_t
>(linkageType));
367 case spirv::Decoration::FPFastMathMode:
368 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
369 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
372 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
373 << stringifyDecoration(decoration);
374 case spirv::Decoration::FPRoundingMode:
375 if (
auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
376 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
379 return emitError(loc,
"expected FPRoundingModeAttr attribute for ")
380 << stringifyDecoration(decoration);
381 case spirv::Decoration::Binding:
382 case spirv::Decoration::DescriptorSet:
383 case spirv::Decoration::Location:
384 case spirv::Decoration::Index:
385 case spirv::Decoration::Offset:
386 case spirv::Decoration::XfbBuffer:
387 case spirv::Decoration::XfbStride:
388 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
389 args.push_back(intAttr.getValue().getZExtValue());
392 return emitError(loc,
"expected integer attribute for ")
393 << stringifyDecoration(decoration);
394 case spirv::Decoration::BuiltIn:
395 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
396 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
398 args.push_back(
static_cast<uint32_t
>(*enumVal));
402 << stringifyDecoration(decoration) <<
" decoration attribute "
403 << strAttr.getValue();
405 return emitError(loc,
"expected string attribute for ")
406 << stringifyDecoration(decoration);
407 case spirv::Decoration::Aliased:
408 case spirv::Decoration::AliasedPointer:
409 case spirv::Decoration::Flat:
410 case spirv::Decoration::NonReadable:
411 case spirv::Decoration::NonWritable:
412 case spirv::Decoration::NoPerspective:
413 case spirv::Decoration::NoSignedWrap:
414 case spirv::Decoration::NoUnsignedWrap:
415 case spirv::Decoration::RelaxedPrecision:
416 case spirv::Decoration::Restrict:
417 case spirv::Decoration::RestrictPointer:
418 case spirv::Decoration::NoContraction:
419 case spirv::Decoration::Constant:
420 case spirv::Decoration::Block:
421 case spirv::Decoration::BufferBlock:
422 case spirv::Decoration::Invariant:
423 case spirv::Decoration::Patch:
424 case spirv::Decoration::Coherent:
427 if (isa<UnitAttr, DecorationAttr>(attr))
430 "expected unit attribute or decoration attribute for ")
431 << stringifyDecoration(decoration);
432 case spirv::Decoration::CacheControlLoadINTEL:
434 loc, decoration, attr,
"CacheControlLoadINTEL",
435 [&](CacheControlLoadINTELAttr attr) {
436 unsigned cacheLevel = attr.getCacheLevel();
437 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
438 return emitDecoration(
439 resultID, decoration,
440 {cacheLevel,
static_cast<uint32_t
>(loadCacheControl)});
442 case spirv::Decoration::CacheControlStoreINTEL:
444 loc, decoration, attr,
"CacheControlStoreINTEL",
445 [&](CacheControlStoreINTELAttr attr) {
446 unsigned cacheLevel = attr.getCacheLevel();
447 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
448 return emitDecoration(
449 resultID, decoration,
450 {cacheLevel,
static_cast<uint32_t
>(storeCacheControl)});
452 case spirv::Decoration::AlignmentId:
453 case spirv::Decoration::MaxByteOffsetId:
454 case spirv::Decoration::CounterBuffer: {
455 auto symRef = dyn_cast<FlatSymbolRefAttr>(attr);
457 return emitError(loc,
"expected symbol reference for ")
458 << stringifyDecoration(decoration);
459 StringRef symName = symRef.getValue();
460 uint32_t operandID = getVariableID(symName);
462 operandID = getSpecConstID(symName);
464 return emitError(loc,
"could not find <id> for symbol '")
465 << symName <<
"' referenced by "
466 << stringifyDecoration(decoration);
467 return emitDecorationId(resultID, decoration, {operandID});
470 return emitError(loc,
"unhandled decoration ")
471 << stringifyDecoration(decoration);
473 return emitDecoration(resultID, decoration, args);
476LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
477 NamedAttribute attr) {
478 StringRef attrName = attr.
getName().strref();
480 std::optional<Decoration> decoration =
481 spirv::symbolizeDecoration(decorationName);
484 loc,
"non-argument attributes expected to have snake-case-ified "
485 "decoration name, unhandled attribute with name : ")
488 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
491LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
492 assert(!name.empty() &&
"unexpected empty string for OpName");
493 if (!options.emitSymbolName)
496 SmallVector<uint32_t, 4> nameOperands;
497 nameOperands.push_back(resultID);
504LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
508 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
514LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
518 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
523LogicalResult Serializer::processMemberDecoration(
528 static_cast<uint32_t
>(memberDecoration.
decoration)});
544bool Serializer::isInterfaceStructPtrType(Type type)
const {
545 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
546 switch (ptrType.getStorageClass()) {
547 case spirv::StorageClass::PhysicalStorageBuffer:
548 case spirv::StorageClass::PushConstant:
549 case spirv::StorageClass::StorageBuffer:
550 case spirv::StorageClass::Uniform:
551 return isa<spirv::StructType>(ptrType.getPointeeType());
559LogicalResult Serializer::processType(Location loc, Type type,
564 return processTypeImpl(loc, type, typeID, serializationCtx);
568Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
580 IntegerType::SignednessSemantics::Signless);
583 typeID = getTypeID(type);
587 typeID = getNextID();
588 SmallVector<uint32_t, 4> operands;
590 operands.push_back(typeID);
591 auto typeEnum = spirv::Opcode::OpTypeVoid;
592 bool deferSerialization =
false;
594 if ((isa<FunctionType>(type) &&
595 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
597 (isa<GraphType>(type) &&
599 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
600 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
601 deferSerialization, serializationCtx))) {
602 if (deferSerialization)
605 typeIDMap[type] = typeID;
607 if (typeEnum == spirv::Opcode::OpTypeStruct)
608 encodeInstructionWithContinuationInto(typesGlobalValues, typeEnum,
613 if (recursiveStructInfos.count(type) != 0) {
616 for (
auto &ptrInfo : recursiveStructInfos[type]) {
619 SmallVector<uint32_t, 4> ptrOperands;
620 ptrOperands.push_back(ptrInfo.pointerTypeID);
621 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
622 ptrOperands.push_back(typeIDMap[type]);
628 recursiveStructInfos[type].clear();
634 return emitError(loc,
"failed to process type: ") << type;
637LogicalResult Serializer::prepareBasicType(
638 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
639 SmallVectorImpl<uint32_t> &operands,
bool &deferSerialization,
641 deferSerialization =
false;
643 if (isVoidType(type)) {
644 typeEnum = spirv::Opcode::OpTypeVoid;
648 if (
auto intType = dyn_cast<IntegerType>(type)) {
649 if (intType.getWidth() == 1) {
650 typeEnum = spirv::Opcode::OpTypeBool;
654 typeEnum = spirv::Opcode::OpTypeInt;
655 operands.push_back(intType.getWidth());
660 operands.push_back(intType.isSigned() ? 1 : 0);
664 if (
auto floatType = dyn_cast<FloatType>(type)) {
665 typeEnum = spirv::Opcode::OpTypeFloat;
666 operands.push_back(floatType.getWidth());
667 if (floatType.isBF16()) {
668 operands.push_back(
static_cast<uint32_t
>(spirv::FPEncoding::BFloat16KHR));
670 if (floatType.isF8E4M3FN()) {
672 static_cast<uint32_t
>(spirv::FPEncoding::Float8E4M3EXT));
674 if (floatType.isF8E5M2()) {
676 static_cast<uint32_t
>(spirv::FPEncoding::Float8E5M2EXT));
682 if (
auto vectorType = dyn_cast<VectorType>(type)) {
683 uint32_t elementTypeID = 0;
684 if (
failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
685 serializationCtx))) {
688 typeEnum = spirv::Opcode::OpTypeVector;
689 operands.push_back(elementTypeID);
690 operands.push_back(vectorType.getNumElements());
694 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
695 typeEnum = spirv::Opcode::OpTypeImage;
696 uint32_t sampledTypeID = 0;
697 if (
failed(processType(loc, imageType.getElementType(), sampledTypeID)))
700 llvm::append_values(operands, sampledTypeID,
701 static_cast<uint32_t
>(imageType.getDim()),
702 static_cast<uint32_t
>(imageType.getDepthInfo()),
703 static_cast<uint32_t
>(imageType.getArrayedInfo()),
704 static_cast<uint32_t
>(imageType.getSamplingInfo()),
705 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
706 static_cast<uint32_t
>(imageType.getImageFormat()));
710 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
711 typeEnum = spirv::Opcode::OpTypeArray;
712 uint32_t elementTypeID = 0;
713 if (
failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
714 serializationCtx))) {
717 operands.push_back(elementTypeID);
718 if (
auto elementCountID = prepareConstantInt(
719 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) {
720 operands.push_back(elementCountID);
722 return processTypeDecoration(loc, arrayType, resultID);
725 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
726 uint32_t pointeeTypeID = 0;
727 spirv::StructType pointeeStruct =
728 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
731 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
736 SmallVector<uint32_t, 2> forwardPtrOperands;
737 forwardPtrOperands.push_back(resultID);
738 forwardPtrOperands.push_back(
739 static_cast<uint32_t
>(ptrType.getStorageClass()));
742 spirv::Opcode::OpTypeForwardPointer,
754 deferSerialization =
true;
758 recursiveStructInfos[structType].push_back(
759 {resultID, ptrType.getStorageClass()});
761 if (
failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
766 typeEnum = spirv::Opcode::OpTypePointer;
767 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
768 operands.push_back(pointeeTypeID);
773 if (isInterfaceStructPtrType(ptrType)) {
774 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
775 if (!structType.hasDecoration(spirv::Decoration::Block) &&
776 !structType.hasDecoration(spirv::Decoration::BufferBlock))
777 if (
failed(emitDecoration(getTypeID(pointeeStruct),
778 spirv::Decoration::Block)))
779 return emitError(loc,
"cannot decorate ")
780 << pointeeStruct <<
" with Block decoration";
786 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
787 uint32_t elementTypeID = 0;
788 if (
failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
789 elementTypeID, serializationCtx))) {
792 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
793 operands.push_back(elementTypeID);
794 return processTypeDecoration(loc, runtimeArrayType, resultID);
797 if (isa<spirv::SamplerType>(type)) {
798 typeEnum = spirv::Opcode::OpTypeSampler;
802 if (isa<spirv::NamedBarrierType>(type)) {
803 typeEnum = spirv::Opcode::OpTypeNamedBarrier;
807 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
808 typeEnum = spirv::Opcode::OpTypeSampledImage;
809 uint32_t imageTypeID = 0;
811 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
814 operands.push_back(imageTypeID);
818 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
819 if (structType.isIdentified()) {
820 if (
failed(processName(resultID, structType.getIdentifier())))
822 serializationCtx.insert(structType.getIdentifier());
825 bool hasOffset = structType.hasOffset();
826 for (
auto elementIndex :
827 llvm::seq<uint32_t>(0, structType.getNumElements())) {
828 uint32_t elementTypeID = 0;
829 if (
failed(processTypeImpl(loc, structType.getElementType(elementIndex),
830 elementTypeID, serializationCtx))) {
833 operands.push_back(elementTypeID);
835 auto intType = IntegerType::get(structType.getContext(), 32);
837 spirv::StructType::MemberDecorationInfo offsetDecoration{
838 elementIndex, spirv::Decoration::Offset,
839 IntegerAttr::get(intType,
840 structType.getMemberOffset(elementIndex))};
841 if (
failed(processMemberDecoration(resultID, offsetDecoration))) {
842 return emitError(loc,
"cannot decorate ")
843 << elementIndex <<
"-th member of " << structType
844 <<
" with its offset";
848 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
849 structType.getMemberDecorations(memberDecorations);
851 for (
auto &memberDecoration : memberDecorations) {
852 if (
failed(processMemberDecoration(resultID, memberDecoration))) {
853 return emitError(loc,
"cannot decorate ")
854 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
855 <<
"-th member of " << structType <<
" with "
856 << stringifyDecoration(memberDecoration.
decoration);
860 SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
861 structType.getStructDecorations(structDecorations);
863 for (spirv::StructType::StructDecorationInfo &structDecoration :
865 if (
failed(processDecorationAttr(loc, resultID,
866 structDecoration.decoration,
867 structDecoration.decorationValue))) {
868 return emitError(loc,
"cannot decorate struct ")
869 << structType <<
" with "
870 << stringifyDecoration(structDecoration.decoration);
874 typeEnum = spirv::Opcode::OpTypeStruct;
876 if (structType.isIdentified())
877 serializationCtx.remove(structType.getIdentifier());
882 if (
auto cooperativeMatrixType =
883 dyn_cast<spirv::CooperativeMatrixType>(type)) {
884 uint32_t elementTypeID = 0;
885 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
886 elementTypeID, serializationCtx))) {
889 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
890 auto getConstantOp = [&](uint32_t id) {
891 auto attr = IntegerAttr::get(IntegerType::get(type.
getContext(), 32),
id);
892 return prepareConstantInt(loc, attr);
895 operands, elementTypeID,
896 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
897 getConstantOp(cooperativeMatrixType.getRows()),
898 getConstantOp(cooperativeMatrixType.getColumns()),
899 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
903 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
904 uint32_t elementTypeID = 0;
905 if (
failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
906 serializationCtx))) {
909 typeEnum = spirv::Opcode::OpTypeMatrix;
910 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
914 if (
auto tensorArmType = dyn_cast<TensorArmType>(type)) {
915 uint32_t elementTypeID = 0;
917 uint32_t shapeID = 0;
919 if (
failed(processTypeImpl(loc, tensorArmType.getElementType(),
920 elementTypeID, serializationCtx))) {
923 if (tensorArmType.hasRank()) {
924 ArrayRef<int64_t> dims = tensorArmType.getShape();
926 rankID = prepareConstantInt(loc, mlirBuilder.getI32IntegerAttr(rank));
931 bool shaped = llvm::all_of(dims, [](
const auto &dim) {
return dim > 0; });
932 if (rank > 0 && shaped) {
933 auto I32Type = IntegerType::get(type.
getContext(), 32);
936 SmallVector<uint64_t, 1> index(rank);
937 shapeID = prepareDenseElementsConstant(
939 mlirBuilder.getI32TensorAttr(SmallVector<int32_t>(dims)), 0,
942 shapeID = prepareArrayConstant(
944 mlirBuilder.getI32ArrayAttr(SmallVector<int32_t>(dims)));
951 typeEnum = spirv::Opcode::OpTypeTensorARM;
952 operands.push_back(elementTypeID);
955 operands.push_back(rankID);
958 operands.push_back(shapeID);
963 return emitError(loc,
"unhandled type in serialization: ") << type;
967Serializer::prepareFunctionType(Location loc, FunctionType type,
968 spirv::Opcode &typeEnum,
969 SmallVectorImpl<uint32_t> &operands) {
970 typeEnum = spirv::Opcode::OpTypeFunction;
971 assert(type.getNumResults() <= 1 &&
972 "serialization supports only a single return value");
973 uint32_t resultID = 0;
975 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
979 operands.push_back(resultID);
980 for (
auto &res : type.getInputs()) {
981 uint32_t argTypeID = 0;
982 if (
failed(processType(loc, res, argTypeID))) {
985 operands.push_back(argTypeID);
991Serializer::prepareGraphType(Location loc, GraphType type,
992 spirv::Opcode &typeEnum,
993 SmallVectorImpl<uint32_t> &operands) {
994 typeEnum = spirv::Opcode::OpTypeGraphARM;
995 assert(type.getNumResults() >= 1 &&
996 "serialization requires at least a return value");
998 operands.push_back(type.getNumInputs());
1000 for (Type argType : type.getInputs()) {
1001 uint32_t argTypeID = 0;
1002 if (
failed(processType(loc, argType, argTypeID)))
1004 operands.push_back(argTypeID);
1007 for (Type resType : type.getResults()) {
1008 uint32_t resTypeID = 0;
1009 if (
failed(processType(loc, resType, resTypeID)))
1011 operands.push_back(resTypeID);
1021uint32_t Serializer::prepareConstant(Location loc, Type constType,
1022 Attribute valueAttr) {
1023 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
1030 if (
auto id = getConstantID(valueAttr)) {
1034 uint32_t typeID = 0;
1035 if (
failed(processType(loc, constType, typeID))) {
1039 uint32_t resultID = 0;
1040 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
1041 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
1042 SmallVector<uint64_t, 4> index(rank);
1043 resultID = prepareDenseElementsConstant(loc, constType, attr,
1045 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
1046 resultID = prepareArrayConstant(loc, constType, arrayAttr);
1049 if (resultID == 0) {
1050 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
1054 constIDMap[valueAttr] = resultID;
1058uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
1060 uint32_t typeID = 0;
1061 if (
failed(processType(loc, constType, typeID))) {
1065 uint32_t resultID = getNextID();
1066 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1067 operands.reserve(attr.size() + 2);
1068 spirv::CompositeType compositeType = cast<spirv::CompositeType>(constType);
1069 for (
auto [idx, elementAttr] : llvm::enumerate(attr)) {
1070 if (uint32_t elementID = prepareConstant(
1072 operands.push_back(elementID);
1077 encodeInstructionWithContinuationInto(
1078 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1086Serializer::prepareDenseElementsConstant(Location loc, Type constType,
1087 DenseElementsAttr valueAttr,
int dim,
1088 MutableArrayRef<uint64_t> index) {
1089 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
1090 assert(dim <= shapedType.getRank());
1091 if (shapedType.getRank() == dim) {
1092 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
1093 return attr.getType().getElementType().isInteger(1)
1094 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
1095 : prepareConstantInt(loc,
1096 attr.getValues<IntegerAttr>()[index]);
1098 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
1099 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
1104 uint32_t typeID = 0;
1105 if (
failed(processType(loc, constType, typeID))) {
1109 int64_t numberOfConstituents = shapedType.getDimSize(dim);
1110 uint32_t resultID = getNextID();
1111 SmallVector<uint32_t, 4> operands = {typeID, resultID};
1112 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
1113 if (
auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
1114 ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
1115 if (!innerShape.empty())
1123 if (isa<spirv::CooperativeMatrixType>(constType)) {
1127 "cannot serialize a non-splat value for a cooperative matrix type");
1132 operands.reserve(3);
1135 if (
auto elementID = prepareDenseElementsConstant(
1136 loc, elementType, valueAttr, shapedType.getRank(), index)) {
1137 operands.push_back(elementID);
1141 }
else if (isa<spirv::TensorArmType>(constType) &&
isZeroValue(valueAttr)) {
1143 {typeID, resultID});
1146 operands.reserve(numberOfConstituents + 2);
1147 for (
int i = 0; i < numberOfConstituents; ++i) {
1149 if (
auto elementID = prepareDenseElementsConstant(
1150 loc, elementType, valueAttr, dim + 1, index)) {
1151 operands.push_back(elementID);
1157 encodeInstructionWithContinuationInto(
1158 typesGlobalValues, spirv::Opcode::OpConstantComposite, operands);
1163uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
1165 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1166 return prepareConstantFp(loc, floatAttr, isSpec);
1168 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1169 return prepareConstantBool(loc, boolAttr, isSpec);
1171 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1172 return prepareConstantInt(loc, intAttr, isSpec);
1178uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
1182 if (
auto id = getConstantID(boolAttr)) {
1188 uint32_t typeID = 0;
1189 if (
failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
1193 auto resultID = getNextID();
1195 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1196 : spirv::Opcode::OpConstantTrue)
1197 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1198 : spirv::Opcode::OpConstantFalse);
1202 constIDMap[boolAttr] = resultID;
1207uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
1211 if (
auto id = getConstantID(intAttr)) {
1217 uint32_t typeID = 0;
1218 if (
failed(processType(loc, intAttr.getType(), typeID))) {
1222 auto resultID = getNextID();
1223 APInt value = intAttr.getValue();
1224 unsigned bitwidth = value.getBitWidth();
1225 bool isSigned = intAttr.getType().isSignedInteger();
1227 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1240 word =
static_cast<int32_t
>(value.getSExtValue());
1242 word =
static_cast<uint32_t
>(value.getZExtValue());
1254 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1256 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1259 {typeID, resultID, words.word1, words.word2});
1262 std::string valueStr;
1263 llvm::raw_string_ostream rss(valueStr);
1264 value.print(rss,
false);
1267 << bitwidth <<
"-bit integer literal: " << valueStr;
1273 constIDMap[intAttr] = resultID;
1278uint32_t Serializer::prepareGraphConstantId(Location loc, Type graphConstType,
1279 IntegerAttr intAttr) {
1281 if (uint32_t
id = getGraphConstantARMId(intAttr)) {
1286 uint32_t typeID = 0;
1287 if (
failed(processType(loc, graphConstType, typeID))) {
1291 uint32_t resultID = getNextID();
1292 APInt value = intAttr.getValue();
1293 unsigned bitwidth = value.getBitWidth();
1294 if (bitwidth > 32) {
1295 emitError(loc,
"Too wide attribute for OpGraphConstantARM: ")
1296 << bitwidth <<
" bits";
1299 bool isSigned = value.isSignedIntN(bitwidth);
1303 word =
static_cast<int32_t
>(value.getSExtValue());
1305 word =
static_cast<uint32_t
>(value.getZExtValue());
1308 {typeID, resultID, word});
1309 graphConstIDMap[intAttr] = resultID;
1313uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
1317 if (
auto id = getConstantID(floatAttr)) {
1323 uint32_t typeID = 0;
1324 if (
failed(processType(loc, floatAttr.getType(), typeID))) {
1328 auto resultID = getNextID();
1329 APFloat value = floatAttr.getValue();
1330 const llvm::fltSemantics *semantics = &value.getSemantics();
1333 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1335 if (semantics == &APFloat::IEEEsingle()) {
1336 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1338 }
else if (semantics == &APFloat::IEEEdouble()) {
1342 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1344 {typeID, resultID, words.word1, words.word2});
1345 }
else if (llvm::is_contained({&APFloat::IEEEhalf(), &APFloat::BFloat(),
1346 &APFloat::Float8E4M3FN(),
1347 &APFloat::Float8E5M2()},
1350 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
1353 std::string valueStr;
1354 llvm::raw_string_ostream rss(valueStr);
1358 << floatAttr.getType() <<
"-typed float literal: " << valueStr;
1363 constIDMap[floatAttr] = resultID;
1372 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1373 return typedAttr.getType();
1376 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1383uint32_t Serializer::prepareConstantCompositeReplicate(
Location loc,
1386 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1387 if (uint32_t
id = getConstantCompositeReplicateID(valueTypePair)) {
1391 uint32_t typeID = 0;
1392 if (
failed(processType(loc, resultType, typeID))) {
1400 auto compositeType = dyn_cast<CompositeType>(resultType);
1405 uint32_t constandID;
1406 if (elementType == valueType) {
1407 constandID = prepareConstant(loc, elementType, valueAttr);
1409 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1412 uint32_t resultID = getNextID();
1413 if (dyn_cast<spirv::TensorArmType>(resultType) &&
isZeroValue(valueAttr)) {
1415 {typeID, resultID});
1418 spirv::Opcode::OpConstantCompositeReplicateEXT,
1419 {typeID, resultID, constandID});
1422 constCompositeReplicateIDMap[valueTypePair] = resultID;
1430uint32_t Serializer::getOrCreateBlockID(
Block *block) {
1431 if (uint32_t
id = getBlockID(block))
1433 return blockIDMap[block] = getNextID();
1437void Serializer::printBlock(
Block *block, raw_ostream &os) {
1438 os <<
"block " << block <<
" (id = ";
1439 if (uint32_t
id = getBlockID(block))
1448Serializer::processBlock(
Block *block,
bool omitLabel,
1450 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1451 LLVM_DEBUG(block->
print(llvm::dbgs()));
1452 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1454 uint32_t blockID = getOrCreateBlockID(block);
1455 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1462 if (
failed(emitPhiForBlockArguments(block)))
1472 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1475 emitMerge =
nullptr;
1478 uint32_t blockID = getNextID();
1484 for (Operation &op : llvm::drop_end(*block)) {
1485 if (
failed(processOperation(&op)))
1493 if (
failed(processOperation(&block->
back())))
1499LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1505 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1512 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
1514 auto *terminator = mlirPredecessor->getTerminator();
1515 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1516 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1517 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1526 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1527 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1528 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1529 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1530 }
else if (
auto branchCondOp =
1531 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1532 std::optional<OperandRange> blockOperands;
1533 if (branchCondOp.getTrueTarget() == block) {
1534 blockOperands = branchCondOp.getTrueTargetOperands();
1536 assert(branchCondOp.getFalseTarget() == block);
1537 blockOperands = branchCondOp.getFalseTargetOperands();
1539 assert(!blockOperands->empty() &&
1540 "expected non-empty block operand range");
1541 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1542 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
1543 std::optional<OperandRange> blockOperands;
1544 if (block == switchOp.getDefaultTarget()) {
1545 blockOperands = switchOp.getDefaultOperands();
1547 SuccessorRange targets = switchOp.getTargets();
1548 auto it = llvm::find(targets, block);
1549 assert(it != targets.end());
1550 size_t index = std::distance(targets.begin(), it);
1551 blockOperands = switchOp.getTargetOperands(index);
1553 assert(!blockOperands->empty() &&
1554 "expected non-empty block operand range");
1555 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1557 return terminator->emitError(
"unimplemented terminator for Phi creation");
1560 llvm::dbgs() <<
" block arguments:\n";
1561 for (Value v : predecessors.back().second)
1562 llvm::dbgs() <<
" " << v <<
"\n";
1567 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1571 uint32_t phiTypeID = 0;
1574 uint32_t phiID = getNextID();
1576 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1577 << arg <<
" (id = " << phiID <<
")\n");
1580 SmallVector<uint32_t, 8> phiArgs;
1581 phiArgs.push_back(phiTypeID);
1582 phiArgs.push_back(phiID);
1584 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1585 Value value = predecessors[predIndex].second[argIndex];
1586 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1587 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1588 <<
") value " << value <<
' ');
1590 uint32_t valueId = getValueID(value);
1594 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1595 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1598 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1600 phiArgs.push_back(valueId);
1602 phiArgs.push_back(predBlockId);
1606 valueIDMap[arg] = phiID;
1616LogicalResult Serializer::encodeExtensionInstruction(
1617 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1618 ArrayRef<uint32_t> operands, SmallVectorImpl<uint32_t> &binary) {
1620 auto &setID = extendedInstSetIDMap[extensionSetName];
1622 setID = getNextID();
1623 SmallVector<uint32_t, 16> importOperands;
1624 importOperands.push_back(setID);
1632 if (operands.size() < 2) {
1633 return op->
emitError(
"extended instructions must have a result encoding");
1635 SmallVector<uint32_t, 8> extInstOperands;
1636 extInstOperands.reserve(operands.size() + 2);
1637 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1638 extInstOperands.push_back(setID);
1639 extInstOperands.push_back(extensionOpcode);
1640 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1645LogicalResult Serializer::encodeExtensionInstruction(
1646 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1647 ArrayRef<uint32_t> operands) {
1648 if (
failed(encodeExtensionInstruction(op, extensionSetName, extensionOpcode,
1649 operands, functionBody)))
1652 if (extensionSetName ==
extTosa)
1653 updateTosaOpsMap(op);
1658LogicalResult Serializer::processOperation(Operation *opInst) {
1659 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1664 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1665 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1666 .Case([&](spirv::BranchConditionalOp op) {
1667 return processBranchConditionalOp(op);
1669 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1670 .Case([&](spirv::CompositeConstructOp op) {
1671 return processCompositeConstructOp(op);
1673 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1674 return processConstantCompositeReplicateOp(op);
1676 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1677 .Case([&](spirv::GraphARMOp op) {
return processGraphARMOp(op); })
1678 .Case([&](spirv::GraphEntryPointARMOp op) {
1679 return processGraphEntryPointARMOp(op);
1681 .Case([&](spirv::GraphOutputsARMOp op) {
1682 return processGraphOutputsARMOp(op);
1684 .Case([&](spirv::GlobalVariableOp op) {
1685 return processGlobalVariableOp(op);
1687 .Case([&](spirv::GraphConstantARMOp op) {
1688 return processGraphConstantARMOp(op);
1690 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1691 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1692 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1693 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1694 .Case([&](spirv::SpecConstantCompositeOp op) {
1695 return processSpecConstantCompositeOp(op);
1697 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1698 return processSpecConstantCompositeReplicateOp(op);
1700 .Case([&](spirv::SpecConstantOperationOp op) {
1701 return processSpecConstantOperationOp(op);
1703 .Case([&](spirv::SwitchOp op) {
return processSwitchOp(op); })
1704 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1705 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1710 [&](Operation *op) {
return dispatchToAutogenSerialization(op); });
1714Serializer::processCompositeConstructOp(spirv::CompositeConstructOp op) {
1715 Location loc = op.getLoc();
1717 uint32_t resultTypeID = 0;
1718 if (
failed(processType(loc, op.getType(), resultTypeID)))
1721 uint32_t resultID = getNextID();
1722 valueIDMap[op.getResult()] = resultID;
1724 SmallVector<uint32_t, 8> operands;
1725 operands.reserve(2 + op.getConstituents().size());
1726 operands.push_back(resultTypeID);
1727 operands.push_back(resultID);
1728 for (Value constituent : op.getConstituents()) {
1729 uint32_t
id = getValueID(constituent);
1730 assert(
id &&
"use before def!");
1731 operands.push_back(
id);
1734 if (
failed(emitDebugLine(functionBody, loc)))
1737 encodeInstructionWithContinuationInto(
1738 functionBody, spirv::Opcode::OpCompositeConstruct, operands);
1740 for (
auto attr : op->getAttrs()) {
1741 if (
failed(processDecoration(loc, resultID, attr)))
1748LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
1749 StringRef extInstSet,
1751 SmallVector<uint32_t, 4> operands;
1752 Location loc = op->
getLoc();
1754 uint32_t resultID = 0;
1756 uint32_t resultTypeID = 0;
1759 operands.push_back(resultTypeID);
1761 resultID = getNextID();
1762 operands.push_back(resultID);
1763 valueIDMap[op->
getResult(0)] = resultID;
1767 operands.push_back(getValueID(operand));
1771 if (
failed(emitDebugLine(functionBody, loc)))
1774 if (extInstSet.empty()) {
1778 if (
failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1784 if (
failed(processDecoration(loc, resultID, attr)))
1792void Serializer::updateTosaOpsMap(Operation *op) {
1793 if (!options.emitDebugInfo)
1796 if (
auto graphOp = dyn_cast<spirv::GraphARMOp>(op->
getParentOp())) {
1797 if (uint32_t graphID = getFunctionID(graphOp.getName()))
1798 tosaOpsMap[graphID][op->
getLoc()].insert(op);
1802LogicalResult Serializer::emitDecoration(uint32_t
target,
1803 spirv::Decoration decoration,
1804 ArrayRef<uint32_t> params) {
1805 uint32_t wordCount = 3 + params.size();
1806 llvm::append_values(
1809 static_cast<uint32_t
>(decoration));
1810 llvm::append_range(decorations, params);
1814LogicalResult Serializer::emitDecorationId(uint32_t
target,
1815 spirv::Decoration decoration,
1816 ArrayRef<uint32_t> operandIds) {
1817 uint32_t wordCount = 3 + operandIds.size();
1818 llvm::append_values(
1821 static_cast<uint32_t
>(decoration));
1822 llvm::append_range(decorations, operandIds);
1826LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
1828 if (!options.emitDebugInfo)
1831 if (lastProcessedWasMergeInst) {
1832 lastProcessedWasMergeInst =
false;
1836 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1839 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
static bool isZeroValue(Attribute attr)
static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp)
Move all function declarations before function definitions.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
llvm::iplist< Operation > OpListType
This is the list of operations in the block.
bool getValue() const
Return the boolean value of this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
void moveBefore(Operation *existingOp)
Unlink this operation from its current block and insert it right before existingOp which may be in th...
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
Type getElementType(unsigned) const
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
std::optional< spirv::Opcode > getContinuationOpcode(spirv::Opcode parent)
Returns the SPV_INTEL_long_composites continuation opcode that may follow parent, or std::nullopt if ...
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
constexpr uint32_t kMaxWordCount
Max number of words https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_universal_limits.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
constexpr llvm::StringLiteral extTosa
Extension set name for TOSA ops.
static LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
llvm::TypeSwitch< T, ResultT > TypeSwitch
llvm::function_ref< Fn > function_ref
Attribute decorationValue