21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
30 #define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
74 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
77 if (
auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
80 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
83 if (
auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
86 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
99 uint32_t wordCount = 1 + operands.size();
101 binary.append(operands.begin(), operands.end());
109 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
111 if (
failed(module.verifyInvariants()))
116 if (
failed(processExtension())) {
119 processMemoryModel();
124 for (
auto &op : *module.getBody()) {
125 if (
failed(processOperation(&op))) {
130 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
136 extensions.size() + extendedSets.size() +
137 memoryModel.size() + entryPoints.size() +
138 executionModes.size() + decorations.size() +
139 typesGlobalValues.size() + functions.size();
142 binary.reserve(moduleSize);
146 binary.append(capabilities.begin(), capabilities.end());
147 binary.append(extensions.begin(), extensions.end());
148 binary.append(extendedSets.begin(), extendedSets.end());
149 binary.append(memoryModel.begin(), memoryModel.end());
150 binary.append(entryPoints.begin(), entryPoints.end());
151 binary.append(executionModes.begin(), executionModes.end());
152 binary.append(debug.begin(), debug.end());
153 binary.append(names.begin(), names.end());
154 binary.append(decorations.begin(), decorations.end());
155 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
156 binary.append(functions.begin(), functions.end());
161 os <<
"\n= Value <id> Map =\n\n";
162 for (
auto valueIDPair : valueIDMap) {
163 Value val = valueIDPair.first;
164 os <<
" " << val <<
" "
165 <<
"id = " << valueIDPair.second <<
' ';
167 os <<
"from op '" << op->getName() <<
"'";
168 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
169 Block *block = arg.getOwner();
170 os <<
"from argument of block " << block <<
' ';
182 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
183 auto funcID = funcIDMap.lookup(fnName);
185 funcID = getNextID();
186 funcIDMap[fnName] = funcID;
191 void Serializer::processCapability() {
192 for (
auto cap : module.getVceTriple()->getCapabilities())
194 {
static_cast<uint32_t
>(cap)});
197 void Serializer::processDebugInfo() {
200 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
201 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
202 fileID = getNextID();
204 operands.push_back(fileID);
210 LogicalResult Serializer::processExtension() {
212 llvm::SmallSet<Extension, 4> deducedExts(
213 llvm::from_range, module.getVceTriple()->getExtensions());
214 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
215 if (options.
emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
217 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
218 return module.emitError(
219 "SPV_KHR_non_semantic_info extension not available");
220 deducedExts.insert(nonSemanticInfoExt);
222 for (spirv::Extension ext : deducedExts) {
230 void Serializer::processMemoryModel() {
231 StringAttr memoryModelName = module.getMemoryModelAttrName();
232 auto mm =
static_cast<uint32_t
>(
233 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
236 StringAttr addressingModelName = module.getAddressingModelAttrName();
237 auto am =
static_cast<uint32_t
>(
238 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
247 if (attrName ==
"fp_fast_math_mode")
248 return "FPFastMathMode";
250 if (attrName ==
"fp_rounding_mode")
251 return "FPRoundingMode";
253 if (attrName ==
"cache_control_load_intel")
254 return "CacheControlLoadINTEL";
255 if (attrName ==
"cache_control_store_intel")
256 return "CacheControlStoreINTEL";
258 return llvm::convertToCamelFromSnakeCase(attrName,
true);
261 template <
typename AttrTy,
typename EmitF>
265 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
267 return emitError(loc,
"expecting array attribute of ")
268 << attrName <<
" for " << stringifyDecoration(decoration);
270 if (arrayAttr.empty()) {
271 return emitError(loc,
"expecting non-empty array attribute of ")
272 << attrName <<
" for " << stringifyDecoration(decoration);
274 for (
Attribute attr : arrayAttr.getValue()) {
275 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
276 if (!cacheControlAttr) {
277 return emitError(loc,
"expecting array attribute of ")
278 << attrName <<
" for " << stringifyDecoration(decoration);
282 if (
failed(emitter(cacheControlAttr)))
288 LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
289 Decoration decoration,
292 switch (decoration) {
293 case spirv::Decoration::LinkageAttributes: {
296 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
297 auto linkageName = linkageAttr.getLinkageName();
298 auto linkageType = linkageAttr.getLinkageType().getValue();
302 args.push_back(
static_cast<uint32_t
>(linkageType));
305 case spirv::Decoration::FPFastMathMode:
306 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
307 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
310 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
311 << stringifyDecoration(decoration);
312 case spirv::Decoration::FPRoundingMode:
313 if (
auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
314 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
317 return emitError(loc,
"expected FPRoundingModeAttr attribute for ")
318 << stringifyDecoration(decoration);
319 case spirv::Decoration::Binding:
320 case spirv::Decoration::DescriptorSet:
321 case spirv::Decoration::Location:
322 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
323 args.push_back(intAttr.getValue().getZExtValue());
326 return emitError(loc,
"expected integer attribute for ")
327 << stringifyDecoration(decoration);
328 case spirv::Decoration::BuiltIn:
329 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
330 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
332 args.push_back(
static_cast<uint32_t
>(*enumVal));
336 << stringifyDecoration(decoration) <<
" decoration attribute "
337 << strAttr.getValue();
339 return emitError(loc,
"expected string attribute for ")
340 << stringifyDecoration(decoration);
341 case spirv::Decoration::Aliased:
342 case spirv::Decoration::AliasedPointer:
343 case spirv::Decoration::Flat:
344 case spirv::Decoration::NonReadable:
345 case spirv::Decoration::NonWritable:
346 case spirv::Decoration::NoPerspective:
347 case spirv::Decoration::NoSignedWrap:
348 case spirv::Decoration::NoUnsignedWrap:
349 case spirv::Decoration::RelaxedPrecision:
350 case spirv::Decoration::Restrict:
351 case spirv::Decoration::RestrictPointer:
352 case spirv::Decoration::NoContraction:
353 case spirv::Decoration::Constant:
354 case spirv::Decoration::Block:
355 case spirv::Decoration::Invariant:
356 case spirv::Decoration::Patch:
359 if (isa<UnitAttr, DecorationAttr>(attr))
362 "expected unit attribute or decoration attribute for ")
363 << stringifyDecoration(decoration);
364 case spirv::Decoration::CacheControlLoadINTEL:
365 return processDecorationList<CacheControlLoadINTELAttr>(
366 loc, decoration, attr,
"CacheControlLoadINTEL",
367 [&](CacheControlLoadINTELAttr attr) {
368 unsigned cacheLevel = attr.getCacheLevel();
369 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
370 return emitDecoration(
371 resultID, decoration,
372 {cacheLevel,
static_cast<uint32_t
>(loadCacheControl)});
374 case spirv::Decoration::CacheControlStoreINTEL:
375 return processDecorationList<CacheControlStoreINTELAttr>(
376 loc, decoration, attr,
"CacheControlStoreINTEL",
377 [&](CacheControlStoreINTELAttr attr) {
378 unsigned cacheLevel = attr.getCacheLevel();
379 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
380 return emitDecoration(
381 resultID, decoration,
382 {cacheLevel,
static_cast<uint32_t
>(storeCacheControl)});
385 return emitError(loc,
"unhandled decoration ")
386 << stringifyDecoration(decoration);
388 return emitDecoration(resultID, decoration, args);
391 LogicalResult Serializer::processDecoration(
Location loc, uint32_t resultID,
393 StringRef attrName = attr.
getName().strref();
395 std::optional<Decoration> decoration =
396 spirv::symbolizeDecoration(decorationName);
399 loc,
"non-argument attributes expected to have snake-case-ified "
400 "decoration name, unhandled attribute with name : ")
403 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
406 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
407 assert(!name.empty() &&
"unexpected empty string for OpName");
412 nameOperands.push_back(resultID);
419 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
423 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
429 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
433 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
438 LogicalResult Serializer::processMemberDecoration(
443 static_cast<uint32_t
>(memberDecoration.
decoration)});
459 bool Serializer::isInterfaceStructPtrType(
Type type)
const {
460 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
461 switch (ptrType.getStorageClass()) {
462 case spirv::StorageClass::PhysicalStorageBuffer:
463 case spirv::StorageClass::PushConstant:
464 case spirv::StorageClass::StorageBuffer:
465 case spirv::StorageClass::Uniform:
466 return isa<spirv::StructType>(ptrType.getPointeeType());
474 LogicalResult Serializer::processType(
Location loc,
Type type,
479 return processTypeImpl(loc, type, typeID, serializationCtx);
483 Serializer::processTypeImpl(
Location loc,
Type type, uint32_t &typeID,
495 IntegerType::SignednessSemantics::Signless);
498 typeID = getTypeID(type);
502 typeID = getNextID();
505 operands.push_back(typeID);
506 auto typeEnum = spirv::Opcode::OpTypeVoid;
507 bool deferSerialization =
false;
509 if ((isa<FunctionType>(type) &&
510 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
512 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
513 deferSerialization, serializationCtx))) {
514 if (deferSerialization)
517 typeIDMap[type] = typeID;
521 if (recursiveStructInfos.count(type) != 0) {
524 for (
auto &ptrInfo : recursiveStructInfos[type]) {
528 ptrOperands.push_back(ptrInfo.pointerTypeID);
529 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
530 ptrOperands.push_back(typeIDMap[type]);
536 recursiveStructInfos[type].clear();
545 LogicalResult Serializer::prepareBasicType(
546 Location loc,
Type type, uint32_t resultID, spirv::Opcode &typeEnum,
549 deferSerialization =
false;
551 if (isVoidType(type)) {
552 typeEnum = spirv::Opcode::OpTypeVoid;
556 if (
auto intType = dyn_cast<IntegerType>(type)) {
557 if (intType.getWidth() == 1) {
558 typeEnum = spirv::Opcode::OpTypeBool;
562 typeEnum = spirv::Opcode::OpTypeInt;
563 operands.push_back(intType.getWidth());
568 operands.push_back(intType.isSigned() ? 1 : 0);
572 if (
auto floatType = dyn_cast<FloatType>(type)) {
573 typeEnum = spirv::Opcode::OpTypeFloat;
574 operands.push_back(floatType.getWidth());
575 if (floatType.isBF16()) {
576 operands.push_back(
static_cast<uint32_t
>(spirv::FPEncoding::BFloat16KHR));
581 if (
auto vectorType = dyn_cast<VectorType>(type)) {
582 uint32_t elementTypeID = 0;
583 if (
failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
584 serializationCtx))) {
587 typeEnum = spirv::Opcode::OpTypeVector;
588 operands.push_back(elementTypeID);
589 operands.push_back(vectorType.getNumElements());
593 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
594 typeEnum = spirv::Opcode::OpTypeImage;
595 uint32_t sampledTypeID = 0;
596 if (
failed(processType(loc, imageType.getElementType(), sampledTypeID)))
599 llvm::append_values(operands, sampledTypeID,
600 static_cast<uint32_t
>(imageType.getDim()),
601 static_cast<uint32_t
>(imageType.getDepthInfo()),
602 static_cast<uint32_t
>(imageType.getArrayedInfo()),
603 static_cast<uint32_t
>(imageType.getSamplingInfo()),
604 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
605 static_cast<uint32_t
>(imageType.getImageFormat()));
609 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
610 typeEnum = spirv::Opcode::OpTypeArray;
611 uint32_t elementTypeID = 0;
612 if (
failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
613 serializationCtx))) {
616 operands.push_back(elementTypeID);
617 if (
auto elementCountID = prepareConstantInt(
619 operands.push_back(elementCountID);
621 return processTypeDecoration(loc, arrayType, resultID);
624 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
625 uint32_t pointeeTypeID = 0;
627 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
630 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
636 forwardPtrOperands.push_back(resultID);
637 forwardPtrOperands.push_back(
638 static_cast<uint32_t
>(ptrType.getStorageClass()));
641 spirv::Opcode::OpTypeForwardPointer,
653 deferSerialization =
true;
657 recursiveStructInfos[structType].push_back(
658 {resultID, ptrType.getStorageClass()});
660 if (
failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
665 typeEnum = spirv::Opcode::OpTypePointer;
666 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
667 operands.push_back(pointeeTypeID);
672 if (isInterfaceStructPtrType(ptrType)) {
673 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
674 if (!structType.hasDecoration(spirv::Decoration::Block))
675 if (
failed(emitDecoration(getTypeID(pointeeStruct),
676 spirv::Decoration::Block)))
677 return emitError(loc,
"cannot decorate ")
678 << pointeeStruct <<
" with Block decoration";
684 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
685 uint32_t elementTypeID = 0;
686 if (
failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
687 elementTypeID, serializationCtx))) {
690 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
691 operands.push_back(elementTypeID);
692 return processTypeDecoration(loc, runtimeArrayType, resultID);
695 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
696 typeEnum = spirv::Opcode::OpTypeSampledImage;
697 uint32_t imageTypeID = 0;
699 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
702 operands.push_back(imageTypeID);
706 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
707 if (structType.isIdentified()) {
708 if (
failed(processName(resultID, structType.getIdentifier())))
710 serializationCtx.insert(structType.getIdentifier());
713 bool hasOffset = structType.hasOffset();
714 for (
auto elementIndex :
715 llvm::seq<uint32_t>(0, structType.getNumElements())) {
716 uint32_t elementTypeID = 0;
717 if (
failed(processTypeImpl(loc, structType.getElementType(elementIndex),
718 elementTypeID, serializationCtx))) {
721 operands.push_back(elementTypeID);
726 elementIndex, spirv::Decoration::Offset,
728 structType.getMemberOffset(elementIndex))};
729 if (
failed(processMemberDecoration(resultID, offsetDecoration))) {
730 return emitError(loc,
"cannot decorate ")
731 << elementIndex <<
"-th member of " << structType
732 <<
" with its offset";
737 structType.getMemberDecorations(memberDecorations);
739 for (
auto &memberDecoration : memberDecorations) {
740 if (
failed(processMemberDecoration(resultID, memberDecoration))) {
741 return emitError(loc,
"cannot decorate ")
742 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
743 <<
"-th member of " << structType <<
" with "
744 << stringifyDecoration(memberDecoration.
decoration);
749 structType.getStructDecorations(structDecorations);
753 if (
failed(processDecorationAttr(loc, resultID,
754 structDecoration.decoration,
755 structDecoration.decorationValue))) {
756 return emitError(loc,
"cannot decorate struct ")
757 << structType <<
" with "
758 << stringifyDecoration(structDecoration.decoration);
762 typeEnum = spirv::Opcode::OpTypeStruct;
764 if (structType.isIdentified())
765 serializationCtx.remove(structType.getIdentifier());
770 if (
auto cooperativeMatrixType =
771 dyn_cast<spirv::CooperativeMatrixType>(type)) {
772 uint32_t elementTypeID = 0;
773 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
774 elementTypeID, serializationCtx))) {
777 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
778 auto getConstantOp = [&](uint32_t id) {
780 return prepareConstantInt(loc, attr);
783 operands, elementTypeID,
784 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
785 getConstantOp(cooperativeMatrixType.getRows()),
786 getConstantOp(cooperativeMatrixType.getColumns()),
787 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
791 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
792 uint32_t elementTypeID = 0;
793 if (
failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
794 serializationCtx))) {
797 typeEnum = spirv::Opcode::OpTypeMatrix;
798 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
802 if (
auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
803 uint32_t elementTypeID = 0;
805 uint32_t shapeID = 0;
807 if (
failed(processTypeImpl(loc, tensorArmType.getElementType(),
808 elementTypeID, serializationCtx))) {
811 if (tensorArmType.hasRank()) {
819 bool shaped = llvm::all_of(dims, [](
const auto &dim) {
return dim > 0; });
820 if (rank > 0 && shaped) {
825 shapeID = prepareDenseElementsConstant(
830 shapeID = prepareArrayConstant(
839 typeEnum = spirv::Opcode::OpTypeTensorARM;
840 operands.push_back(elementTypeID);
843 operands.push_back(rankID);
846 operands.push_back(shapeID);
851 return emitError(loc,
"unhandled type in serialization: ") << type;
855 Serializer::prepareFunctionType(
Location loc, FunctionType type,
856 spirv::Opcode &typeEnum,
858 typeEnum = spirv::Opcode::OpTypeFunction;
859 assert(type.getNumResults() <= 1 &&
860 "serialization supports only a single return value");
861 uint32_t resultID = 0;
863 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
867 operands.push_back(resultID);
868 for (
auto &res : type.getInputs()) {
869 uint32_t argTypeID = 0;
870 if (
failed(processType(loc, res, argTypeID))) {
873 operands.push_back(argTypeID);
882 uint32_t Serializer::prepareConstant(
Location loc,
Type constType,
884 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
891 if (
auto id = getConstantID(valueAttr)) {
896 if (
failed(processType(loc, constType, typeID))) {
900 uint32_t resultID = 0;
901 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
902 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
904 resultID = prepareDenseElementsConstant(loc, constType, attr,
906 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
907 resultID = prepareArrayConstant(loc, constType, arrayAttr);
911 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
915 constIDMap[valueAttr] = resultID;
919 uint32_t Serializer::prepareArrayConstant(
Location loc,
Type constType,
922 if (
failed(processType(loc, constType, typeID))) {
926 uint32_t resultID = getNextID();
928 operands.reserve(attr.size() + 2);
929 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
931 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
932 operands.push_back(elementID);
937 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
946 Serializer::prepareDenseElementsConstant(
Location loc,
Type constType,
949 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
950 assert(dim <= shapedType.getRank());
951 if (shapedType.getRank() == dim) {
952 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
953 return attr.getType().getElementType().isInteger(1)
954 ? prepareConstantBool(loc, attr.getValues<
BoolAttr>()[index])
955 : prepareConstantInt(loc,
956 attr.getValues<IntegerAttr>()[index]);
958 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
959 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
965 if (
failed(processType(loc, constType, typeID))) {
969 int64_t numberOfConstituents = shapedType.getDimSize(dim);
970 uint32_t resultID = getNextID();
972 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
973 if (
auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
975 if (!innerShape.empty())
983 if (isa<spirv::CooperativeMatrixType>(constType)) {
987 "cannot serialize a non-splat value for a cooperative matrix type");
995 if (
auto elementID = prepareDenseElementsConstant(
996 loc, elementType, valueAttr, shapedType.getRank(), index)) {
997 operands.push_back(elementID);
1001 }
else if (isa<spirv::TensorArmType>(constType) &&
isZeroValue(valueAttr)) {
1003 {typeID, resultID});
1006 operands.reserve(numberOfConstituents + 2);
1007 for (
int i = 0; i < numberOfConstituents; ++i) {
1009 if (
auto elementID = prepareDenseElementsConstant(
1010 loc, elementType, valueAttr, dim + 1, index)) {
1011 operands.push_back(elementID);
1017 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1025 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1026 return prepareConstantFp(loc, floatAttr, isSpec);
1028 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1029 return prepareConstantBool(loc, boolAttr, isSpec);
1031 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1032 return prepareConstantInt(loc, intAttr, isSpec);
1042 if (
auto id = getConstantID(boolAttr)) {
1048 uint32_t typeID = 0;
1049 if (
failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
1053 auto resultID = getNextID();
1055 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1056 : spirv::Opcode::OpConstantTrue)
1057 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1058 : spirv::Opcode::OpConstantFalse);
1062 constIDMap[boolAttr] = resultID;
1067 uint32_t Serializer::prepareConstantInt(
Location loc, IntegerAttr intAttr,
1071 if (
auto id = getConstantID(intAttr)) {
1077 uint32_t typeID = 0;
1078 if (
failed(processType(loc, intAttr.getType(), typeID))) {
1082 auto resultID = getNextID();
1083 APInt value = intAttr.getValue();
1084 unsigned bitwidth = value.getBitWidth();
1085 bool isSigned = intAttr.getType().isSignedInteger();
1087 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1100 word =
static_cast<int32_t
>(value.getSExtValue());
1102 word =
static_cast<uint32_t
>(value.getZExtValue());
1114 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1116 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1119 {typeID, resultID, words.word1, words.word2});
1122 std::string valueStr;
1123 llvm::raw_string_ostream rss(valueStr);
1124 value.print(rss,
false);
1127 << bitwidth <<
"-bit integer literal: " << valueStr;
1133 constIDMap[intAttr] = resultID;
1138 uint32_t Serializer::prepareConstantFp(
Location loc, FloatAttr floatAttr,
1142 if (
auto id = getConstantID(floatAttr)) {
1148 uint32_t typeID = 0;
1149 if (
failed(processType(loc, floatAttr.getType(), typeID))) {
1153 auto resultID = getNextID();
1154 APFloat value = floatAttr.getValue();
1155 const llvm::fltSemantics *semantics = &value.getSemantics();
1158 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1160 if (semantics == &APFloat::IEEEsingle()) {
1161 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1163 }
else if (semantics == &APFloat::IEEEdouble()) {
1167 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1169 {typeID, resultID, words.word1, words.word2});
1170 }
else if (semantics == &APFloat::IEEEhalf() ||
1171 semantics == &APFloat::BFloat()) {
1173 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
1176 std::string valueStr;
1177 llvm::raw_string_ostream rss(valueStr);
1181 << floatAttr.getType() <<
"-typed float literal: " << valueStr;
1186 constIDMap[floatAttr] = resultID;
1195 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1196 return typedAttr.getType();
1199 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1206 uint32_t Serializer::prepareConstantCompositeReplicate(
Location loc,
1209 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1210 if (uint32_t
id = getConstantCompositeReplicateID(valueTypePair)) {
1214 uint32_t typeID = 0;
1215 if (
failed(processType(loc, resultType, typeID))) {
1223 auto compositeType = dyn_cast<CompositeType>(resultType);
1226 Type elementType = compositeType.getElementType(0);
1228 uint32_t constandID;
1229 if (elementType == valueType) {
1230 constandID = prepareConstant(loc, elementType, valueAttr);
1232 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1235 uint32_t resultID = getNextID();
1236 if (dyn_cast<spirv::TensorArmType>(resultType) &&
isZeroValue(valueAttr)) {
1238 {typeID, resultID});
1241 spirv::Opcode::OpConstantCompositeReplicateEXT,
1242 {typeID, resultID, constandID});
1245 constCompositeReplicateIDMap[valueTypePair] = resultID;
1253 uint32_t Serializer::getOrCreateBlockID(
Block *block) {
1254 if (uint32_t
id = getBlockID(block))
1256 return blockIDMap[block] = getNextID();
1260 void Serializer::printBlock(
Block *block, raw_ostream &os) {
1261 os <<
"block " << block <<
" (id = ";
1262 if (uint32_t
id = getBlockID(block))
1271 Serializer::processBlock(
Block *block,
bool omitLabel,
1273 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1274 LLVM_DEBUG(block->
print(llvm::dbgs()));
1275 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1277 uint32_t blockID = getOrCreateBlockID(block);
1278 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1285 if (
failed(emitPhiForBlockArguments(block)))
1295 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1298 emitMerge =
nullptr;
1301 uint32_t blockID = getNextID();
1307 for (
Operation &op : llvm::drop_end(*block)) {
1308 if (
failed(processOperation(&op)))
1316 if (
failed(processOperation(&block->
back())))
1322 LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1328 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1337 auto *terminator = mlirPredecessor->getTerminator();
1338 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1339 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1340 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1349 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1350 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1351 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1352 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1353 }
else if (
auto branchCondOp =
1354 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1355 std::optional<OperandRange> blockOperands;
1356 if (branchCondOp.getTrueTarget() == block) {
1357 blockOperands = branchCondOp.getTrueTargetOperands();
1359 assert(branchCondOp.getFalseTarget() == block);
1360 blockOperands = branchCondOp.getFalseTargetOperands();
1363 assert(!blockOperands->empty() &&
1364 "expected non-empty block operand range");
1365 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1367 return terminator->emitError(
"unimplemented terminator for Phi creation");
1370 llvm::dbgs() <<
" block arguments:\n";
1371 for (
Value v : predecessors.back().second)
1372 llvm::dbgs() <<
" " << v <<
"\n";
1377 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1381 uint32_t phiTypeID = 0;
1384 uint32_t phiID = getNextID();
1386 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1387 << arg <<
" (id = " << phiID <<
")\n");
1391 phiArgs.push_back(phiTypeID);
1392 phiArgs.push_back(phiID);
1394 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1395 Value value = predecessors[predIndex].second[argIndex];
1396 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1397 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1398 <<
") value " << value <<
' ');
1400 uint32_t valueId = getValueID(value);
1404 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1405 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1408 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1410 phiArgs.push_back(valueId);
1412 phiArgs.push_back(predBlockId);
1416 valueIDMap[arg] = phiID;
1426 LogicalResult Serializer::encodeExtensionInstruction(
1427 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1430 auto &setID = extendedInstSetIDMap[extensionSetName];
1432 setID = getNextID();
1434 importOperands.push_back(setID);
1442 if (operands.size() < 2) {
1443 return op->
emitError(
"extended instructions must have a result encoding");
1446 extInstOperands.reserve(operands.size() + 2);
1447 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1448 extInstOperands.push_back(setID);
1449 extInstOperands.push_back(extensionOpcode);
1450 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1456 LogicalResult Serializer::processOperation(
Operation *opInst) {
1457 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1462 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1463 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1464 .Case([&](spirv::BranchConditionalOp op) {
1465 return processBranchConditionalOp(op);
1467 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1468 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1469 return processConstantCompositeReplicateOp(op);
1471 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1472 .Case([&](spirv::GlobalVariableOp op) {
1473 return processGlobalVariableOp(op);
1475 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1476 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1477 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1478 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1479 .Case([&](spirv::SpecConstantCompositeOp op) {
1480 return processSpecConstantCompositeOp(op);
1482 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1483 return processSpecConstantCompositeReplicateOp(op);
1485 .Case([&](spirv::SpecConstantOperationOp op) {
1486 return processSpecConstantOperationOp(op);
1488 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1489 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1494 [&](
Operation *op) {
return dispatchToAutogenSerialization(op); });
1497 LogicalResult Serializer::processOpWithoutGrammarAttr(
Operation *op,
1498 StringRef extInstSet,
1503 uint32_t resultID = 0;
1505 uint32_t resultTypeID = 0;
1508 operands.push_back(resultTypeID);
1510 resultID = getNextID();
1511 operands.push_back(resultID);
1512 valueIDMap[op->
getResult(0)] = resultID;
1516 operands.push_back(getValueID(operand));
1518 if (
failed(emitDebugLine(functionBody, loc)))
1521 if (extInstSet.empty()) {
1525 if (
failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1531 if (
failed(processDecoration(loc, resultID, attr)))
1539 LogicalResult Serializer::emitDecoration(uint32_t target,
1540 spirv::Decoration decoration,
1542 uint32_t wordCount = 3 + params.size();
1543 llvm::append_values(
1546 static_cast<uint32_t
>(decoration));
1547 llvm::append_range(decorations, params);
1556 if (lastProcessedWasMergeInst) {
1557 lastProcessedWasMergeInst =
false;
1561 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1564 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static bool isZeroValue(Attribute attr)
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
An attribute that represents a reference to a dense vector or tensor object.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPRI-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.
Attribute decorationValue