21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
30 #define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
74 if (
auto floatAttr = dyn_cast<FloatAttr>(attr)) {
75 return floatAttr.getValue().isZero();
77 if (
auto boolAttr = dyn_cast<BoolAttr>(attr)) {
78 return !boolAttr.getValue();
80 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
81 return intAttr.getValue().isZero();
83 if (
auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
86 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
99 uint32_t wordCount = 1 + operands.size();
101 binary.append(operands.begin(), operands.end());
109 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
111 if (
failed(module.verifyInvariants()))
116 if (
failed(processExtension())) {
119 processMemoryModel();
124 for (
auto &op : *module.getBody()) {
125 if (
failed(processOperation(&op))) {
130 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
136 extensions.size() + extendedSets.size() +
137 memoryModel.size() + entryPoints.size() +
138 executionModes.size() + decorations.size() +
139 typesGlobalValues.size() + functions.size() + graphs.size();
142 binary.reserve(moduleSize);
146 binary.append(capabilities.begin(), capabilities.end());
147 binary.append(extensions.begin(), extensions.end());
148 binary.append(extendedSets.begin(), extendedSets.end());
149 binary.append(memoryModel.begin(), memoryModel.end());
150 binary.append(entryPoints.begin(), entryPoints.end());
151 binary.append(executionModes.begin(), executionModes.end());
152 binary.append(debug.begin(), debug.end());
153 binary.append(names.begin(), names.end());
154 binary.append(decorations.begin(), decorations.end());
155 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
156 binary.append(functions.begin(), functions.end());
157 binary.append(graphs.begin(), graphs.end());
162 os <<
"\n= Value <id> Map =\n\n";
163 for (
auto valueIDPair : valueIDMap) {
164 Value val = valueIDPair.first;
165 os <<
" " << val <<
" "
166 <<
"id = " << valueIDPair.second <<
' ';
168 os <<
"from op '" << op->getName() <<
"'";
169 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
170 Block *block = arg.getOwner();
171 os <<
"from argument of block " << block <<
' ';
183 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
184 auto funcID = funcIDMap.lookup(fnName);
186 funcID = getNextID();
187 funcIDMap[fnName] = funcID;
192 void Serializer::processCapability() {
193 for (
auto cap : module.getVceTriple()->getCapabilities())
195 {
static_cast<uint32_t
>(cap)});
198 void Serializer::processDebugInfo() {
201 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
202 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
203 fileID = getNextID();
205 operands.push_back(fileID);
211 LogicalResult Serializer::processExtension() {
213 llvm::SmallSet<Extension, 4> deducedExts(
214 llvm::from_range, module.getVceTriple()->getExtensions());
215 auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
216 if (options.
emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
218 if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
219 return module.emitError(
220 "SPV_KHR_non_semantic_info extension not available");
221 deducedExts.insert(nonSemanticInfoExt);
223 for (spirv::Extension ext : deducedExts) {
231 void Serializer::processMemoryModel() {
232 StringAttr memoryModelName = module.getMemoryModelAttrName();
233 auto mm =
static_cast<uint32_t
>(
234 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
237 StringAttr addressingModelName = module.getAddressingModelAttrName();
238 auto am =
static_cast<uint32_t
>(
239 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
248 if (attrName ==
"fp_fast_math_mode")
249 return "FPFastMathMode";
251 if (attrName ==
"fp_rounding_mode")
252 return "FPRoundingMode";
254 if (attrName ==
"cache_control_load_intel")
255 return "CacheControlLoadINTEL";
256 if (attrName ==
"cache_control_store_intel")
257 return "CacheControlStoreINTEL";
259 return llvm::convertToCamelFromSnakeCase(attrName,
true);
262 template <
typename AttrTy,
typename EmitF>
266 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
268 return emitError(loc,
"expecting array attribute of ")
269 << attrName <<
" for " << stringifyDecoration(decoration);
271 if (arrayAttr.empty()) {
272 return emitError(loc,
"expecting non-empty array attribute of ")
273 << attrName <<
" for " << stringifyDecoration(decoration);
275 for (
Attribute attr : arrayAttr.getValue()) {
276 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
277 if (!cacheControlAttr) {
278 return emitError(loc,
"expecting array attribute of ")
279 << attrName <<
" for " << stringifyDecoration(decoration);
283 if (
failed(emitter(cacheControlAttr)))
289 LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
290 Decoration decoration,
293 switch (decoration) {
294 case spirv::Decoration::LinkageAttributes: {
297 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
298 auto linkageName = linkageAttr.getLinkageName();
299 auto linkageType = linkageAttr.getLinkageType().getValue();
303 args.push_back(
static_cast<uint32_t
>(linkageType));
306 case spirv::Decoration::FPFastMathMode:
307 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
308 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
311 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
312 << stringifyDecoration(decoration);
313 case spirv::Decoration::FPRoundingMode:
314 if (
auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
315 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
318 return emitError(loc,
"expected FPRoundingModeAttr attribute for ")
319 << stringifyDecoration(decoration);
320 case spirv::Decoration::Binding:
321 case spirv::Decoration::DescriptorSet:
322 case spirv::Decoration::Location:
323 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
324 args.push_back(intAttr.getValue().getZExtValue());
327 return emitError(loc,
"expected integer attribute for ")
328 << stringifyDecoration(decoration);
329 case spirv::Decoration::BuiltIn:
330 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
331 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
333 args.push_back(
static_cast<uint32_t
>(*enumVal));
337 << stringifyDecoration(decoration) <<
" decoration attribute "
338 << strAttr.getValue();
340 return emitError(loc,
"expected string attribute for ")
341 << stringifyDecoration(decoration);
342 case spirv::Decoration::Aliased:
343 case spirv::Decoration::AliasedPointer:
344 case spirv::Decoration::Flat:
345 case spirv::Decoration::NonReadable:
346 case spirv::Decoration::NonWritable:
347 case spirv::Decoration::NoPerspective:
348 case spirv::Decoration::NoSignedWrap:
349 case spirv::Decoration::NoUnsignedWrap:
350 case spirv::Decoration::RelaxedPrecision:
351 case spirv::Decoration::Restrict:
352 case spirv::Decoration::RestrictPointer:
353 case spirv::Decoration::NoContraction:
354 case spirv::Decoration::Constant:
355 case spirv::Decoration::Block:
356 case spirv::Decoration::Invariant:
357 case spirv::Decoration::Patch:
360 if (isa<UnitAttr, DecorationAttr>(attr))
363 "expected unit attribute or decoration attribute for ")
364 << stringifyDecoration(decoration);
365 case spirv::Decoration::CacheControlLoadINTEL:
366 return processDecorationList<CacheControlLoadINTELAttr>(
367 loc, decoration, attr,
"CacheControlLoadINTEL",
368 [&](CacheControlLoadINTELAttr attr) {
369 unsigned cacheLevel = attr.getCacheLevel();
370 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
371 return emitDecoration(
372 resultID, decoration,
373 {cacheLevel,
static_cast<uint32_t
>(loadCacheControl)});
375 case spirv::Decoration::CacheControlStoreINTEL:
376 return processDecorationList<CacheControlStoreINTELAttr>(
377 loc, decoration, attr,
"CacheControlStoreINTEL",
378 [&](CacheControlStoreINTELAttr attr) {
379 unsigned cacheLevel = attr.getCacheLevel();
380 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
381 return emitDecoration(
382 resultID, decoration,
383 {cacheLevel,
static_cast<uint32_t
>(storeCacheControl)});
386 return emitError(loc,
"unhandled decoration ")
387 << stringifyDecoration(decoration);
389 return emitDecoration(resultID, decoration, args);
392 LogicalResult Serializer::processDecoration(
Location loc, uint32_t resultID,
394 StringRef attrName = attr.
getName().strref();
396 std::optional<Decoration> decoration =
397 spirv::symbolizeDecoration(decorationName);
400 loc,
"non-argument attributes expected to have snake-case-ified "
401 "decoration name, unhandled attribute with name : ")
404 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
407 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
408 assert(!name.empty() &&
"unexpected empty string for OpName");
413 nameOperands.push_back(resultID);
420 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
424 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
430 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
434 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
439 LogicalResult Serializer::processMemberDecoration(
444 static_cast<uint32_t
>(memberDecoration.
decoration)});
460 bool Serializer::isInterfaceStructPtrType(
Type type)
const {
461 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
462 switch (ptrType.getStorageClass()) {
463 case spirv::StorageClass::PhysicalStorageBuffer:
464 case spirv::StorageClass::PushConstant:
465 case spirv::StorageClass::StorageBuffer:
466 case spirv::StorageClass::Uniform:
467 return isa<spirv::StructType>(ptrType.getPointeeType());
475 LogicalResult Serializer::processType(
Location loc,
Type type,
480 return processTypeImpl(loc, type, typeID, serializationCtx);
484 Serializer::processTypeImpl(
Location loc,
Type type, uint32_t &typeID,
496 IntegerType::SignednessSemantics::Signless);
499 typeID = getTypeID(type);
503 typeID = getNextID();
506 operands.push_back(typeID);
507 auto typeEnum = spirv::Opcode::OpTypeVoid;
508 bool deferSerialization =
false;
510 if ((isa<FunctionType>(type) &&
511 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
513 (isa<GraphType>(type) &&
515 prepareGraphType(loc, cast<GraphType>(type), typeEnum, operands))) ||
516 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
517 deferSerialization, serializationCtx))) {
518 if (deferSerialization)
521 typeIDMap[type] = typeID;
525 if (recursiveStructInfos.count(type) != 0) {
528 for (
auto &ptrInfo : recursiveStructInfos[type]) {
532 ptrOperands.push_back(ptrInfo.pointerTypeID);
533 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
534 ptrOperands.push_back(typeIDMap[type]);
540 recursiveStructInfos[type].clear();
546 return emitError(loc,
"failed to process type: ") << type;
549 LogicalResult Serializer::prepareBasicType(
550 Location loc,
Type type, uint32_t resultID, spirv::Opcode &typeEnum,
553 deferSerialization =
false;
555 if (isVoidType(type)) {
556 typeEnum = spirv::Opcode::OpTypeVoid;
560 if (
auto intType = dyn_cast<IntegerType>(type)) {
561 if (intType.getWidth() == 1) {
562 typeEnum = spirv::Opcode::OpTypeBool;
566 typeEnum = spirv::Opcode::OpTypeInt;
567 operands.push_back(intType.getWidth());
572 operands.push_back(intType.isSigned() ? 1 : 0);
576 if (
auto floatType = dyn_cast<FloatType>(type)) {
577 typeEnum = spirv::Opcode::OpTypeFloat;
578 operands.push_back(floatType.getWidth());
579 if (floatType.isBF16()) {
580 operands.push_back(
static_cast<uint32_t
>(spirv::FPEncoding::BFloat16KHR));
585 if (
auto vectorType = dyn_cast<VectorType>(type)) {
586 uint32_t elementTypeID = 0;
587 if (
failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
588 serializationCtx))) {
591 typeEnum = spirv::Opcode::OpTypeVector;
592 operands.push_back(elementTypeID);
593 operands.push_back(vectorType.getNumElements());
597 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
598 typeEnum = spirv::Opcode::OpTypeImage;
599 uint32_t sampledTypeID = 0;
600 if (
failed(processType(loc, imageType.getElementType(), sampledTypeID)))
603 llvm::append_values(operands, sampledTypeID,
604 static_cast<uint32_t
>(imageType.getDim()),
605 static_cast<uint32_t
>(imageType.getDepthInfo()),
606 static_cast<uint32_t
>(imageType.getArrayedInfo()),
607 static_cast<uint32_t
>(imageType.getSamplingInfo()),
608 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
609 static_cast<uint32_t
>(imageType.getImageFormat()));
613 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
614 typeEnum = spirv::Opcode::OpTypeArray;
615 uint32_t elementTypeID = 0;
616 if (
failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
617 serializationCtx))) {
620 operands.push_back(elementTypeID);
621 if (
auto elementCountID = prepareConstantInt(
623 operands.push_back(elementCountID);
625 return processTypeDecoration(loc, arrayType, resultID);
628 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
629 uint32_t pointeeTypeID = 0;
631 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
634 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
640 forwardPtrOperands.push_back(resultID);
641 forwardPtrOperands.push_back(
642 static_cast<uint32_t
>(ptrType.getStorageClass()));
645 spirv::Opcode::OpTypeForwardPointer,
657 deferSerialization =
true;
661 recursiveStructInfos[structType].push_back(
662 {resultID, ptrType.getStorageClass()});
664 if (
failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
669 typeEnum = spirv::Opcode::OpTypePointer;
670 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
671 operands.push_back(pointeeTypeID);
676 if (isInterfaceStructPtrType(ptrType)) {
677 auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
678 if (!structType.hasDecoration(spirv::Decoration::Block))
679 if (
failed(emitDecoration(getTypeID(pointeeStruct),
680 spirv::Decoration::Block)))
681 return emitError(loc,
"cannot decorate ")
682 << pointeeStruct <<
" with Block decoration";
688 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
689 uint32_t elementTypeID = 0;
690 if (
failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
691 elementTypeID, serializationCtx))) {
694 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
695 operands.push_back(elementTypeID);
696 return processTypeDecoration(loc, runtimeArrayType, resultID);
699 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
700 typeEnum = spirv::Opcode::OpTypeSampledImage;
701 uint32_t imageTypeID = 0;
703 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
706 operands.push_back(imageTypeID);
710 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
711 if (structType.isIdentified()) {
712 if (
failed(processName(resultID, structType.getIdentifier())))
714 serializationCtx.insert(structType.getIdentifier());
717 bool hasOffset = structType.hasOffset();
718 for (
auto elementIndex :
719 llvm::seq<uint32_t>(0, structType.getNumElements())) {
720 uint32_t elementTypeID = 0;
721 if (
failed(processTypeImpl(loc, structType.getElementType(elementIndex),
722 elementTypeID, serializationCtx))) {
725 operands.push_back(elementTypeID);
730 elementIndex, spirv::Decoration::Offset,
732 structType.getMemberOffset(elementIndex))};
733 if (
failed(processMemberDecoration(resultID, offsetDecoration))) {
734 return emitError(loc,
"cannot decorate ")
735 << elementIndex <<
"-th member of " << structType
736 <<
" with its offset";
741 structType.getMemberDecorations(memberDecorations);
743 for (
auto &memberDecoration : memberDecorations) {
744 if (
failed(processMemberDecoration(resultID, memberDecoration))) {
745 return emitError(loc,
"cannot decorate ")
746 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
747 <<
"-th member of " << structType <<
" with "
748 << stringifyDecoration(memberDecoration.
decoration);
753 structType.getStructDecorations(structDecorations);
757 if (
failed(processDecorationAttr(loc, resultID,
758 structDecoration.decoration,
759 structDecoration.decorationValue))) {
760 return emitError(loc,
"cannot decorate struct ")
761 << structType <<
" with "
762 << stringifyDecoration(structDecoration.decoration);
766 typeEnum = spirv::Opcode::OpTypeStruct;
768 if (structType.isIdentified())
769 serializationCtx.remove(structType.getIdentifier());
774 if (
auto cooperativeMatrixType =
775 dyn_cast<spirv::CooperativeMatrixType>(type)) {
776 uint32_t elementTypeID = 0;
777 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
778 elementTypeID, serializationCtx))) {
781 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
782 auto getConstantOp = [&](uint32_t id) {
784 return prepareConstantInt(loc, attr);
787 operands, elementTypeID,
788 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
789 getConstantOp(cooperativeMatrixType.getRows()),
790 getConstantOp(cooperativeMatrixType.getColumns()),
791 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
795 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
796 uint32_t elementTypeID = 0;
797 if (
failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
798 serializationCtx))) {
801 typeEnum = spirv::Opcode::OpTypeMatrix;
802 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
806 if (
auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
807 uint32_t elementTypeID = 0;
809 uint32_t shapeID = 0;
811 if (
failed(processTypeImpl(loc, tensorArmType.getElementType(),
812 elementTypeID, serializationCtx))) {
815 if (tensorArmType.hasRank()) {
823 bool shaped = llvm::all_of(dims, [](
const auto &dim) {
return dim > 0; });
824 if (rank > 0 && shaped) {
829 shapeID = prepareDenseElementsConstant(
834 shapeID = prepareArrayConstant(
843 typeEnum = spirv::Opcode::OpTypeTensorARM;
844 operands.push_back(elementTypeID);
847 operands.push_back(rankID);
850 operands.push_back(shapeID);
855 return emitError(loc,
"unhandled type in serialization: ") << type;
859 Serializer::prepareFunctionType(
Location loc, FunctionType type,
860 spirv::Opcode &typeEnum,
862 typeEnum = spirv::Opcode::OpTypeFunction;
863 assert(type.getNumResults() <= 1 &&
864 "serialization supports only a single return value");
865 uint32_t resultID = 0;
867 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
871 operands.push_back(resultID);
872 for (
auto &res : type.getInputs()) {
873 uint32_t argTypeID = 0;
874 if (
failed(processType(loc, res, argTypeID))) {
877 operands.push_back(argTypeID);
883 Serializer::prepareGraphType(
Location loc, GraphType type,
884 spirv::Opcode &typeEnum,
886 typeEnum = spirv::Opcode::OpTypeGraphARM;
887 assert(type.getNumResults() >= 1 &&
888 "serialization requires at least a return value");
890 operands.push_back(type.getNumInputs());
892 for (
Type argType : type.getInputs()) {
893 uint32_t argTypeID = 0;
894 if (
failed(processType(loc, argType, argTypeID)))
896 operands.push_back(argTypeID);
899 for (
Type resType : type.getResults()) {
900 uint32_t resTypeID = 0;
901 if (
failed(processType(loc, resType, resTypeID)))
903 operands.push_back(resTypeID);
913 uint32_t Serializer::prepareConstant(
Location loc,
Type constType,
915 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
922 if (
auto id = getConstantID(valueAttr)) {
927 if (
failed(processType(loc, constType, typeID))) {
931 uint32_t resultID = 0;
932 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
933 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
935 resultID = prepareDenseElementsConstant(loc, constType, attr,
937 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
938 resultID = prepareArrayConstant(loc, constType, arrayAttr);
942 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
946 constIDMap[valueAttr] = resultID;
950 uint32_t Serializer::prepareArrayConstant(
Location loc,
Type constType,
953 if (
failed(processType(loc, constType, typeID))) {
957 uint32_t resultID = getNextID();
959 operands.reserve(attr.size() + 2);
960 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
962 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
963 operands.push_back(elementID);
968 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
977 Serializer::prepareDenseElementsConstant(
Location loc,
Type constType,
980 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
981 assert(dim <= shapedType.getRank());
982 if (shapedType.getRank() == dim) {
983 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
984 return attr.getType().getElementType().isInteger(1)
985 ? prepareConstantBool(loc, attr.getValues<
BoolAttr>()[index])
986 : prepareConstantInt(loc,
987 attr.getValues<IntegerAttr>()[index]);
989 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
990 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
996 if (
failed(processType(loc, constType, typeID))) {
1000 int64_t numberOfConstituents = shapedType.getDimSize(dim);
1001 uint32_t resultID = getNextID();
1003 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
1004 if (
auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
1006 if (!innerShape.empty())
1014 if (isa<spirv::CooperativeMatrixType>(constType)) {
1018 "cannot serialize a non-splat value for a cooperative matrix type");
1023 operands.reserve(3);
1026 if (
auto elementID = prepareDenseElementsConstant(
1027 loc, elementType, valueAttr, shapedType.getRank(), index)) {
1028 operands.push_back(elementID);
1032 }
else if (isa<spirv::TensorArmType>(constType) &&
isZeroValue(valueAttr)) {
1034 {typeID, resultID});
1037 operands.reserve(numberOfConstituents + 2);
1038 for (
int i = 0; i < numberOfConstituents; ++i) {
1040 if (
auto elementID = prepareDenseElementsConstant(
1041 loc, elementType, valueAttr, dim + 1, index)) {
1042 operands.push_back(elementID);
1048 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
1056 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
1057 return prepareConstantFp(loc, floatAttr, isSpec);
1059 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
1060 return prepareConstantBool(loc, boolAttr, isSpec);
1062 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
1063 return prepareConstantInt(loc, intAttr, isSpec);
1073 if (
auto id = getConstantID(boolAttr)) {
1079 uint32_t typeID = 0;
1080 if (
failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
1084 auto resultID = getNextID();
1086 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
1087 : spirv::Opcode::OpConstantTrue)
1088 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
1089 : spirv::Opcode::OpConstantFalse);
1093 constIDMap[boolAttr] = resultID;
1098 uint32_t Serializer::prepareConstantInt(
Location loc, IntegerAttr intAttr,
1102 if (
auto id = getConstantID(intAttr)) {
1108 uint32_t typeID = 0;
1109 if (
failed(processType(loc, intAttr.getType(), typeID))) {
1113 auto resultID = getNextID();
1114 APInt value = intAttr.getValue();
1115 unsigned bitwidth = value.getBitWidth();
1116 bool isSigned = intAttr.getType().isSignedInteger();
1118 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1131 word =
static_cast<int32_t
>(value.getSExtValue());
1133 word =
static_cast<uint32_t
>(value.getZExtValue());
1145 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
1147 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
1150 {typeID, resultID, words.word1, words.word2});
1153 std::string valueStr;
1154 llvm::raw_string_ostream rss(valueStr);
1155 value.print(rss,
false);
1158 << bitwidth <<
"-bit integer literal: " << valueStr;
1164 constIDMap[intAttr] = resultID;
1169 uint32_t Serializer::prepareGraphConstantId(
Location loc,
Type graphConstType,
1170 IntegerAttr intAttr) {
1172 if (uint32_t
id = getGraphConstantARMId(intAttr)) {
1177 uint32_t typeID = 0;
1178 if (
failed(processType(loc, graphConstType, typeID))) {
1182 uint32_t resultID = getNextID();
1183 APInt value = intAttr.getValue();
1184 unsigned bitwidth = value.getBitWidth();
1185 if (bitwidth > 32) {
1186 emitError(loc,
"Too wide attribute for OpGraphConstantARM: ")
1187 << bitwidth <<
" bits";
1190 bool isSigned = value.isSignedIntN(bitwidth);
1194 word =
static_cast<int32_t
>(value.getSExtValue());
1196 word =
static_cast<uint32_t
>(value.getZExtValue());
1199 {typeID, resultID, word});
1200 graphConstIDMap[intAttr] = resultID;
1204 uint32_t Serializer::prepareConstantFp(
Location loc, FloatAttr floatAttr,
1208 if (
auto id = getConstantID(floatAttr)) {
1214 uint32_t typeID = 0;
1215 if (
failed(processType(loc, floatAttr.getType(), typeID))) {
1219 auto resultID = getNextID();
1220 APFloat value = floatAttr.getValue();
1221 const llvm::fltSemantics *semantics = &value.getSemantics();
1224 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1226 if (semantics == &APFloat::IEEEsingle()) {
1227 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1229 }
else if (semantics == &APFloat::IEEEdouble()) {
1233 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1235 {typeID, resultID, words.word1, words.word2});
1236 }
else if (semantics == &APFloat::IEEEhalf() ||
1237 semantics == &APFloat::BFloat()) {
1239 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
1242 std::string valueStr;
1243 llvm::raw_string_ostream rss(valueStr);
1247 << floatAttr.getType() <<
"-typed float literal: " << valueStr;
1252 constIDMap[floatAttr] = resultID;
1261 if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
1262 return typedAttr.getType();
1265 if (
auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1272 uint32_t Serializer::prepareConstantCompositeReplicate(
Location loc,
1275 std::pair<Attribute, Type> valueTypePair{valueAttr, resultType};
1276 if (uint32_t
id = getConstantCompositeReplicateID(valueTypePair)) {
1280 uint32_t typeID = 0;
1281 if (
failed(processType(loc, resultType, typeID))) {
1289 auto compositeType = dyn_cast<CompositeType>(resultType);
1292 Type elementType = compositeType.getElementType(0);
1294 uint32_t constandID;
1295 if (elementType == valueType) {
1296 constandID = prepareConstant(loc, elementType, valueAttr);
1298 constandID = prepareConstantCompositeReplicate(loc, elementType, valueAttr);
1301 uint32_t resultID = getNextID();
1302 if (dyn_cast<spirv::TensorArmType>(resultType) &&
isZeroValue(valueAttr)) {
1304 {typeID, resultID});
1307 spirv::Opcode::OpConstantCompositeReplicateEXT,
1308 {typeID, resultID, constandID});
1311 constCompositeReplicateIDMap[valueTypePair] = resultID;
1319 uint32_t Serializer::getOrCreateBlockID(
Block *block) {
1320 if (uint32_t
id = getBlockID(block))
1322 return blockIDMap[block] = getNextID();
1326 void Serializer::printBlock(
Block *block, raw_ostream &os) {
1327 os <<
"block " << block <<
" (id = ";
1328 if (uint32_t
id = getBlockID(block))
1337 Serializer::processBlock(
Block *block,
bool omitLabel,
1339 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1340 LLVM_DEBUG(block->
print(llvm::dbgs()));
1341 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1343 uint32_t blockID = getOrCreateBlockID(block);
1344 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1351 if (
failed(emitPhiForBlockArguments(block)))
1361 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1364 emitMerge =
nullptr;
1367 uint32_t blockID = getNextID();
1373 for (
Operation &op : llvm::drop_end(*block)) {
1374 if (
failed(processOperation(&op)))
1382 if (
failed(processOperation(&block->
back())))
// Emits the phi instructions that stand in for the arguments of `block`.
//
// Two phases are visible below:
//   1. For each predecessor, look at its terminator and record the pair
//      (predecessor block to name in the phi, values forwarded to `block`)
//      in `predecessors`.
//   2. For each block argument, build one phi operand list in `phiArgs`:
//      result type <id>, result <id>, then a (value <id>, predecessor block
//      <id>) pair per recorded predecessor.
//
// Returns a failure (via emitError on the terminator) when a predecessor ends
// in a terminator other than spirv::BranchOp / spirv::BranchConditionalOp.
// NOTE(review): this listing elides source lines (the embedded line numbers
// skip), so loop headers, else branches and the final return are not shown.
1388 LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1394 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1403 auto *terminator = mlirPredecessor->getTerminator();
1404 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1405 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1406 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1415 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1416 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
// An unconditional branch forwards all of its operands to `block`.
1417 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1418 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1419 }
else if (
auto branchCondOp =
1420 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
// A conditional branch carries separate operand lists for its true and false
// targets; select the list belonging to the edge that reaches `block`.
1421 std::optional<OperandRange> blockOperands;
1422 if (branchCondOp.getTrueTarget() == block) {
1423 blockOperands = branchCondOp.getTrueTargetOperands();
1425 assert(branchCondOp.getFalseTarget() == block);
1426 blockOperands = branchCondOp.getFalseTargetOperands();
1429 assert(!blockOperands->empty() &&
1430 "expected non-empty block operand range");
1431 predecessors.emplace_back(spirvPredecessor, *blockOperands);
// Any other terminator kind is not supported for phi creation.
1433 return terminator->emitError(
"unimplemented terminator for Phi creation");
1436 llvm::dbgs() <<
" block arguments:\n";
1437 for (
Value v : predecessors.back().second)
1438 llvm::dbgs() <<
" " << v <<
"\n";
// Second phase: one phi per block argument.
1443 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1447 uint32_t phiTypeID = 0;
1450 uint32_t phiID = getNextID();
1452 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1453 << arg <<
" (id = " << phiID <<
")\n");
// The operand list starts with the result type <id> and the result <id>.
1457 phiArgs.push_back(phiTypeID);
1458 phiArgs.push_back(phiID);
// Then append the incoming value and the predecessor block <id> for every
// recorded predecessor; `argIndex` selects this argument's forwarded value.
1460 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1461 Value value = predecessors[predIndex].second[argIndex];
1462 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1463 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1464 <<
") value " << value <<
' ');
1466 uint32_t valueId = getValueID(value);
// NOTE(review): the guarding condition is elided here. The "(need to fix)"
// path records an offset into functionBody under deferredPhiValues[value];
// presumably this slot is patched once the value gets an <id> - confirm.
1470 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1471 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1474 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1476 phiArgs.push_back(valueId);
1478 phiArgs.push_back(predBlockId);
// From here on, uses of the block argument resolve to the phi's result <id>.
1482 valueIDMap[arg] = phiID;
// Encodes an instruction from an extended instruction set on behalf of `op`.
//
// The <id> of the set named `extensionSetName` is looked up in
// extendedInstSetIDMap; the visible lines assign a fresh <id> from
// getNextID() and push it onto `importOperands`.
// NOTE(review): the condition guarding that assignment is elided in this
// listing; presumably it only runs the first time a set is seen - confirm.
//
// `operands` is then re-packed into `extInstOperands` as:
//   operands[0], operands[1], setID, extensionOpcode, operands[2..end)
// Emits an error on `op` when fewer than two operand words are supplied.
1492 LogicalResult Serializer::encodeExtensionInstruction(
1493 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1496 auto &setID = extendedInstSetIDMap[extensionSetName];
1498 setID = getNextID();
1500 importOperands.push_back(setID);
// The first two operand words (the "result encoding" in the error message)
// must be present, because the set <id> and opcode are inserted after them.
1508 if (operands.size() < 2) {
1509 return op->
emitError(
"extended instructions must have a result encoding");
// Two extra words are needed: the set <id> and the extension opcode.
1512 extInstOperands.reserve(operands.size() + 2);
1513 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1514 extInstOperands.push_back(setID);
1515 extInstOperands.push_back(extensionOpcode);
1516 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
// Serializes a single operation by dispatching on its concrete op class.
//
// Each `.Case` below forwards to the matching process*Op member function and
// returns its result. The final callback, which takes a plain `Operation *`,
// forwards to dispatchToAutogenSerialization(op).
// NOTE(review): the head of the dispatch chain, the closing `})` of the
// multi-line cases and the line introducing the final callback are elided in
// this listing; the final callback is presumably the fallback case - confirm.
1522 LogicalResult Serializer::processOperation(
Operation *opInst) {
1523 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1528 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1529 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1530 .Case([&](spirv::BranchConditionalOp op) {
1531 return processBranchConditionalOp(op);
1533 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1534 .Case([&](spirv::EXTConstantCompositeReplicateOp op) {
1535 return processConstantCompositeReplicateOp(op);
1537 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1538 .Case([&](spirv::GraphARMOp op) {
return processGraphARMOp(op); })
1539 .Case([&](spirv::GraphEntryPointARMOp op) {
1540 return processGraphEntryPointARMOp(op);
1542 .Case([&](spirv::GraphOutputsARMOp op) {
1543 return processGraphOutputsARMOp(op);
1545 .Case([&](spirv::GlobalVariableOp op) {
1546 return processGlobalVariableOp(op);
1548 .Case([&](spirv::GraphConstantARMOp op) {
1549 return processGraphConstantARMOp(op);
1551 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1552 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1553 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1554 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1555 .Case([&](spirv::SpecConstantCompositeOp op) {
1556 return processSpecConstantCompositeOp(op);
1558 .Case([&](spirv::EXTSpecConstantCompositeReplicateOp op) {
1559 return processSpecConstantCompositeReplicateOp(op);
1561 .Case([&](spirv::SpecConstantOperationOp op) {
1562 return processSpecConstantOperationOp(op);
1564 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1565 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1570 [&](
Operation *op) {
return dispatchToAutogenSerialization(op); });
// Serializes `op` from explicitly supplied encoding information.
//
// Visible steps, in order:
//   - build `operands`: result type <id>, a fresh result <id> (also recorded
//     in valueIDMap for op->getResult(0)), then the <id> of each operand;
//   - emit debug line information into functionBody via emitDebugLine;
//   - encode the instruction: one path is taken when `extInstSet` is empty
//     (body elided here), the other goes through encodeExtensionInstruction;
//   - run processDecoration on the result <id> for an attribute.
// NOTE(review): the remaining parameters and the loops/conditions around
// these statements are elided in this listing; presumably the result words
// are only pushed when the op has a result - confirm.
1573 LogicalResult Serializer::processOpWithoutGrammarAttr(
Operation *op,
1574 StringRef extInstSet,
1579 uint32_t resultID = 0;
1581 uint32_t resultTypeID = 0;
1584 operands.push_back(resultTypeID);
1586 resultID = getNextID();
1587 operands.push_back(resultID);
// Later uses of the op's first result resolve to this <id>.
1588 valueIDMap[op->
getResult(0)] = resultID;
1592 operands.push_back(getValueID(operand));
1594 if (
failed(emitDebugLine(functionBody, loc)))
// An empty extended-instruction-set name selects the non-extension encoding.
1597 if (extInstSet.empty()) {
1601 if (
failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1607 if (
failed(processDecoration(loc, resultID, attr)))
// Appends a decoration instruction for the <id> `target` to the module's
// `decorations` section: leading words first (including the `decoration` enum
// value cast to uint32_t), followed by `params`.
// NOTE(review): the rest of the signature and most append_values arguments
// are elided in this listing. The constant 3 in wordCount presumably covers
// the opcode word, the target <id> and the decoration word - confirm.
1615 LogicalResult Serializer::emitDecoration(uint32_t target,
1616 spirv::Decoration decoration,
1618 uint32_t wordCount = 3 + params.size();
1619 llvm::append_values(
1622 static_cast<uint32_t
>(decoration));
1623 llvm::append_range(decorations, params);
1632 if (lastProcessedWasMergeInst) {
1633 lastProcessedWasMergeInst =
false;
1637 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1640 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static bool isZeroValue(Attribute attr)
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the parent block for the SPIR-V OpPhi instructions corresponding to the block arguments.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
DenseIntElementsAttr getI32TensorAttr(ArrayRef< int32_t > values)
Tensor-typed DenseIntElementsAttr getters.
An attribute that represents a reference to a dense vector or tensor object.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e. if all element values are the same.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
static ArrayType get(Type elementType, unsigned elementCount)
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static Type getValueType(Attribute attr)
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op)
Queries the target environment recursively from enclosing symbol table ops containing the given op or...
LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.
Attribute decorationValue