20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
30 #define DEBUG_TYPE "spirv-serialization"
// getStructuredControlFlowOpMergeBlock (excerpt): returns the merge block when
// `op` is a structured control flow op, i.e. a spirv::SelectionOp or a
// spirv::LoopOp. The signature and the result for any other op are on lines
// not included in this excerpt.
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
// getPhiIncomingBlock (excerpt): given a predecessor block of a block with
// arguments, returns the block that should be used as the incoming block of
// the generated OpPhi (presumably its parent block; the original description
// is truncated). Only fragments are visible here: a special case for blocks
// whose parent op is a spirv::LoopOp, and a backward walk over preceding
// operations. The lines in between are not part of this excerpt.
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
// Step `op` backwards through its preceding sibling operations until none
// remain.
55 while ((op = op->getPrevNode()) !=
nullptr)
// encodeInstructionInto (excerpt): encodes a SPIR-V instruction with the given
// opcode and operands into `binary`.
// The word count is the operand count plus one leading word. Per the SPIR-V
// binary format, that leading word packs the word count and the opcode (see
// getPrefixedOpcode). The push of that word is on a line not shown here.
80 uint32_t wordCount = 1 + operands.size();
82 binary.append(operands.begin(), operands.end());
// Serializer::serialize (excerpt): serializes the remembered SPIR-V module.
// Visible steps:
//   - verify the module's invariants;
//   - feed every op in the module body through processOperation;
//   - print debug markers at the start and the end.
// Error returns and the other steps are on lines not included here.
90 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
92 if (failed(module.verifyInvariants()))
103 for (
auto &op : *module.getBody()) {
104 if (failed(processOperation(&op))) {
109 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
// Serializer::collect (excerpt): collects the final SPIR-V binary into
// `binary` by concatenating the per-section word buffers in a fixed order:
// capabilities, extensions, extended instruction sets, memory model, entry
// points, execution modes, debug, names, decorations, types/global values,
// functions. This order follows the logical module layout required by the
// SPIR-V specification.
// As far as this excerpt shows, `moduleSize` is only used to reserve capacity.
// Its leading terms are on a line not shown here.
// NOTE(review): the visible terms of the estimate omit debug.size() and
// names.size(). That can only cost an extra reallocation, not change the
// output. Consider including them.
115 extensions.size() + extendedSets.size() +
116 memoryModel.size() + entryPoints.size() +
117 executionModes.size() + decorations.size() +
118 typesGlobalValues.size() + functions.size();
121 binary.reserve(moduleSize);
125 binary.append(capabilities.begin(), capabilities.end());
126 binary.append(extensions.begin(), extensions.end());
127 binary.append(extendedSets.begin(), extendedSets.end());
128 binary.append(memoryModel.begin(), memoryModel.end());
129 binary.append(entryPoints.begin(), entryPoints.end());
130 binary.append(executionModes.begin(), executionModes.end());
131 binary.append(debug.begin(), debug.end());
132 binary.append(names.begin(), names.end());
133 binary.append(decorations.begin(), decorations.end());
134 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
135 binary.append(functions.begin(), functions.end());
// Serializer::printValueIDMap (excerpt): for debugging, prints each value in
// valueIDMap and its result <id>, followed by where the value comes from:
//   - an op result prints the op's name (the condition selecting this branch
//     is on a line not shown);
//   - a block argument prints its owning block.
140 os <<
"\n= Value <id> Map =\n\n";
141 for (
auto valueIDPair : valueIDMap) {
142 Value val = valueIDPair.first;
143 os <<
" " << val <<
" "
144 <<
"id = " << valueIDPair.second <<
' ';
146 os <<
"from op '" << op->getName() <<
"'";
147 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
148 Block *block = arg.getOwner();
149 os <<
"from argument of block " << block <<
' ';
// Returns the result <id> for the function named `fnName`, looking it up in
// funcIDMap. A fresh <id> from getNextID() is allocated and recorded in
// funcIDMap, presumably only when the lookup found none. The guarding
// condition and the final return are on lines not included in this excerpt.
161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162 auto funcID = funcIDMap.lookup(fnName);
164 funcID = getNextID();
165 funcIDMap[fnName] = funcID;
// Serializer::processCapability (excerpt): emits one instruction per
// capability in module.getVceTriple(). Its operand is the capability enum value
// cast to uint32_t. The encoding call itself is on a line not shown here.
170 void Serializer::processCapability() {
171 for (
auto cap : module.getVceTriple()->getCapabilities())
173 {
static_cast<uint32_t
>(cap)});
// Serializer::processDebugInfo (excerpt): allocates a result <id> for the
// module's source file. The <id> is stored in the `fileID` member, which later
// debug-line operands also reference. It then starts an operand list with it.
// The file name is taken from the module location when that is a
// FileLineColLoc, and is "<unknown>" otherwise.
176 void Serializer::processDebugInfo() {
179 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
181 fileID = getNextID();
183 operands.push_back(fileID);
// Serializer::processExtension (excerpt): iterates over the extensions in
// module.getVceTriple(). The per-extension encoding is on lines not included
// in this excerpt.
189 void Serializer::processExtension() {
191 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
// Serializer::processMemoryModel (excerpt): reads the module's memory-model
// and addressing-model attributes by name and converts each to a uint32_t
// operand (`mm`, `am`).
// NOTE(review): getAttrOfType returns a null attribute when the attribute is
// missing or has another type. The continuation of each expression is not
// shown, so confirm that the module verifier guarantees both attributes exist.
198 void Serializer::processMemoryModel() {
199 StringAttr memoryModelName = module.getMemoryModelAttrName();
200 auto mm =
static_cast<uint32_t
>(
201 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
204 StringAttr addressingModelName = module.getAddressingModelAttrName();
205 auto am =
static_cast<uint32_t
>(
206 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
// getDecorationName (excerpt): maps a snake_case attribute name to its
// CamelCase decoration name. The names below are special-cased because the
// generic conversion at the end would capitalize them differently. For
// example, "fp_fast_math_mode" would become "FpFastMathMode", and the *_intel
// names would end in "Intel" instead of "INTEL".
215 if (attrName ==
"fp_fast_math_mode")
216 return "FPFastMathMode";
218 if (attrName ==
"fp_rounding_mode")
219 return "FPRoundingMode";
221 if (attrName ==
"cache_control_load_intel")
222 return "CacheControlLoadINTEL";
223 if (attrName ==
"cache_control_store_intel")
224 return "CacheControlStoreINTEL";
226 return llvm::convertToCamelFromSnakeCase(attrName,
true);
// processDecorationList (excerpt): handles a decoration whose attribute value
// is a list.
// Requirements on `attrList`:
//   - it must be a non-empty ArrayAttr;
//   - every element must be an AttrTy.
// Otherwise an error naming `attrName` and the decoration is emitted at `loc`.
// Each element is passed to `emitter` and a failure from it is propagated.
// In processDecorationAttr, the emitters emit one decoration per element.
// Some of the surrounding control flow is on lines not shown here.
229 template <
typename AttrTy,
typename EmitF>
233 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
235 return emitError(loc,
"expecting array attribute of ")
236 << attrName <<
" for " << stringifyDecoration(decoration);
238 if (arrayAttr.empty()) {
239 return emitError(loc,
"expecting non-empty array attribute of ")
240 << attrName <<
" for " << stringifyDecoration(decoration);
242 for (
Attribute attr : arrayAttr.getValue()) {
243 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
244 if (!cacheControlAttr) {
245 return emitError(loc,
"expecting array attribute of ")
246 << attrName <<
" for " << stringifyDecoration(decoration);
250 if (failed(emitter(cacheControlAttr)))
// Serializer::processDecorationAttr (excerpt): encodes `decoration`, with
// attribute value `attr`, on `resultID`.
//   - Most decorations: the attribute is type-checked and converted into extra
//     decoration operands (`args`), which emitDecoration emits at the end.
//   - INTEL cache-control decorations: list-valued, emitted one per array
//     element through processDecorationList.
// Attribute type mismatches and unhandled decorations produce an error at
// `loc`.
256 LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
257 Decoration decoration,
260 switch (decoration) {
261 case spirv::Decoration::LinkageAttributes: {
// NOTE(review): the dyn_cast result is dereferenced on the next line without a
// null check. An attribute that is not a LinkageAttributesAttr would crash here
// instead of producing a diagnostic. Consider returning emitError(...) when
// linkageAttr is null.
264 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
265 auto linkageName = linkageAttr.getLinkageName();
266 auto linkageType = linkageAttr.getLinkageType().getValue();
270 args.push_back(
static_cast<uint32_t
>(linkageType));
273 case spirv::Decoration::FPFastMathMode:
274 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
275 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
278 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
279 << stringifyDecoration(decoration);
280 case spirv::Decoration::FPRoundingMode:
281 if (
auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
282 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
285 return emitError(loc,
"expected FPRoundingModeAttr attribute for ")
286 << stringifyDecoration(decoration);
// Integer-valued decorations: the value is zero-extended into one operand.
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::DescriptorSet:
289 case spirv::Decoration::Location:
290 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
291 args.push_back(intAttr.getValue().getZExtValue());
294 return emitError(loc,
"expected integer attribute for ")
295 << stringifyDecoration(decoration);
// BuiltIn is spelled as a string and mapped to the BuiltIn enum. An
// unrecognized name is reported (the start of that message is not shown).
296 case spirv::Decoration::BuiltIn:
297 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
298 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
300 args.push_back(
static_cast<uint32_t
>(*enumVal));
304 << stringifyDecoration(decoration) <<
" decoration attribute "
305 << strAttr.getValue();
307 return emitError(loc,
"expected string attribute for ")
308 << stringifyDecoration(decoration);
// Decorations accepted with a unit attribute or a decoration attribute. No
// operand push is visible for them.
309 case spirv::Decoration::Aliased:
310 case spirv::Decoration::AliasedPointer:
311 case spirv::Decoration::Flat:
312 case spirv::Decoration::NonReadable:
313 case spirv::Decoration::NonWritable:
314 case spirv::Decoration::NoPerspective:
315 case spirv::Decoration::NoSignedWrap:
316 case spirv::Decoration::NoUnsignedWrap:
317 case spirv::Decoration::RelaxedPrecision:
318 case spirv::Decoration::Restrict:
319 case spirv::Decoration::RestrictPointer:
320 case spirv::Decoration::NoContraction:
321 case spirv::Decoration::Constant:
324 if (isa<UnitAttr, DecorationAttr>(attr))
327 "expected unit attribute or decoration attribute for ")
328 << stringifyDecoration(decoration);
// Cache-control decorations: one decoration with operands
// (cache level, control enum) per array element.
329 case spirv::Decoration::CacheControlLoadINTEL:
330 return processDecorationList<CacheControlLoadINTELAttr>(
331 loc, decoration, attr,
"CacheControlLoadINTEL",
332 [&](CacheControlLoadINTELAttr attr) {
333 unsigned cacheLevel = attr.getCacheLevel();
334 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
335 return emitDecoration(
336 resultID, decoration,
337 {cacheLevel,
static_cast<uint32_t
>(loadCacheControl)});
339 case spirv::Decoration::CacheControlStoreINTEL:
340 return processDecorationList<CacheControlStoreINTELAttr>(
341 loc, decoration, attr,
"CacheControlStoreINTEL",
342 [&](CacheControlStoreINTELAttr attr) {
343 unsigned cacheLevel = attr.getCacheLevel();
344 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
345 return emitDecoration(
346 resultID, decoration,
347 {cacheLevel,
static_cast<uint32_t
>(storeCacheControl)});
350 return emitError(loc,
"unhandled decoration ")
351 << stringifyDecoration(decoration);
353 return emitDecoration(resultID, decoration, args);
// Serializer::processDecoration (excerpt): resolves a named attribute to a
// SPIR-V decoration and encodes it on `resultID`.
//   - `decorationName` is symbolized with spirv::symbolizeDecoration. It is
//     presumably derived from the attribute name on a line not shown.
//   - Unknown names are an error.
//   - Otherwise the attribute value is forwarded to processDecorationAttr.
356 LogicalResult Serializer::processDecoration(
Location loc, uint32_t resultID,
358 StringRef attrName = attr.
getName().strref();
360 std::optional<Decoration> decoration =
361 spirv::symbolizeDecoration(decorationName);
364 loc,
"non-argument attributes expected to have snake-case-ified "
365 "decoration name, unhandled attribute with name : ")
368 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
// Serializer::processName (excerpt): builds the operands of an OpName that
// attaches `name` to `resultID`. An empty name is a programming error
// (asserted). Some lines between the assert and the operand construction are
// not included here. Whether emission is gated on an option cannot be seen
// from this excerpt.
371 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
372 assert(!name.empty() &&
"unexpected empty string for OpName");
377 nameOperands.push_back(resultID);
// processTypeDecoration specialization for spirv::ArrayType (excerpt):
// decorates the array type `resultID` with ArrayStride. `stride` is computed
// on lines not shown here, presumably from the type's getArrayStride() (the
// stride in bytes).
384 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
388 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
// processTypeDecoration specialization for spirv::RuntimeArrayType (excerpt):
// decorates the runtime array type `resultID` with ArrayStride. `stride` is
// computed on lines not shown here, presumably from the type's
// getArrayStride() (the stride in bytes).
394 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
398 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
// Serializer::processMemberDecoration (excerpt): encodes one struct-member
// decoration. The visible operand is the decoration enum value cast to
// uint32_t. The rest of the operand list and the emission are not shown.
403 LogicalResult Serializer::processMemberDecoration(
408 static_cast<uint32_t
>(memberDecoration.
decoration)});
// Returns true when `type` is a spirv::PointerType whose pointee is a
// spirv::StructType and whose storage class is one of PhysicalStorageBuffer,
// PushConstant, StorageBuffer or Uniform. prepareBasicType uses this to decide
// whether the pointee struct gets a Block decoration. The result for other
// storage classes and non-pointer types is on lines not included here.
423 bool Serializer::isInterfaceStructPtrType(
Type type)
const {
424 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
425 switch (ptrType.getStorageClass()) {
426 case spirv::StorageClass::PhysicalStorageBuffer:
427 case spirv::StorageClass::PushConstant:
428 case spirv::StorageClass::StorageBuffer:
429 case spirv::StorageClass::Uniform:
430 return isa<spirv::StructType>(ptrType.getPointeeType());
// Serializer::processType (excerpt): entry point for serializing `type`. It
// forwards to processTypeImpl with a `serializationCtx`, declared on lines not
// shown here. prepareBasicType consults that context to detect recursive
// identified structs.
438 LogicalResult Serializer::processType(
Location loc,
Type type,
443 return processTypeImpl(loc, type, typeID, serializationCtx);
// Serializer::processTypeImpl (excerpt): returns in `typeID` the <id> for
// `type`.
//   1. Looks up an existing <id> (getTypeID); otherwise allocates a new one.
//      The early return for cached types is on lines not shown.
//   2. Starts the operand list with that <id>.
//   3. Delegates opcode/operand construction to prepareFunctionType or
//      prepareBasicType.
//   4. On success, records the <id> in typeIDMap.
// If pointer types were deferred on this type (recursive structs), it builds
// their completing operands: pointer type <id>, storage class, and this
// type's <id>. It then clears the pending list.
447 Serializer::processTypeImpl(
Location loc,
Type type, uint32_t &typeID,
449 typeID = getTypeID(type);
453 typeID = getNextID();
456 operands.push_back(typeID);
457 auto typeEnum = spirv::Opcode::OpTypeVoid;
458 bool deferSerialization =
false;
460 if ((isa<FunctionType>(type) &&
461 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
463 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
464 deferSerialization, serializationCtx))) {
// A deferred pointer type is completed later, once its pointee struct is
// done. Handling for it stops here (lines not shown).
465 if (deferSerialization)
468 typeIDMap[type] = typeID;
472 if (recursiveStructInfos.count(type) != 0) {
475 for (
auto &ptrInfo : recursiveStructInfos[type]) {
479 ptrOperands.push_back(ptrInfo.pointerTypeID);
480 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
481 ptrOperands.push_back(typeIDMap[type]);
487 recursiveStructInfos[type].clear();
// Serializer::prepareBasicType (excerpt): for every non-function type, sets
// `typeEnum` to the SPIR-V type opcode and appends that opcode's operands
// after the result <id> that processTypeImpl already pushed.
//   - Element, pointee, member and column types are serialized first so their
//     <id>s can be referenced.
//   - A pointer to an identified struct that is still being serialized is
//     treated as recursive: an OpTypeForwardPointer is emitted and
//     `deferSerialization` is set, so the OpTypePointer is completed once the
//     struct is done.
//   - Unsupported types produce an error at `loc`.
496 LogicalResult Serializer::prepareBasicType(
497 Location loc,
Type type, uint32_t resultID, spirv::Opcode &typeEnum,
500 deferSerialization =
false;
502 if (isVoidType(type)) {
503 typeEnum = spirv::Opcode::OpTypeVoid;
507 if (
auto intType = dyn_cast<IntegerType>(type)) {
// i1 is serialized as OpTypeBool rather than a 1-bit OpTypeInt.
508 if (intType.getWidth() == 1) {
509 typeEnum = spirv::Opcode::OpTypeBool;
513 typeEnum = spirv::Opcode::OpTypeInt;
514 operands.push_back(intType.getWidth());
// Signedness operand: 1 for signed integer types, 0 otherwise.
519 operands.push_back(intType.isSigned() ? 1 : 0);
523 if (
auto floatType = dyn_cast<FloatType>(type)) {
524 typeEnum = spirv::Opcode::OpTypeFloat;
525 operands.push_back(floatType.getWidth());
529 if (
auto vectorType = dyn_cast<VectorType>(type)) {
530 uint32_t elementTypeID = 0;
531 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
532 serializationCtx))) {
535 typeEnum = spirv::Opcode::OpTypeVector;
536 operands.push_back(elementTypeID);
537 operands.push_back(vectorType.getNumElements());
541 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
542 typeEnum = spirv::Opcode::OpTypeImage;
543 uint32_t sampledTypeID = 0;
544 if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
547 llvm::append_values(operands, sampledTypeID,
548 static_cast<uint32_t
>(imageType.getDim()),
549 static_cast<uint32_t
>(imageType.getDepthInfo()),
550 static_cast<uint32_t
>(imageType.getArrayedInfo()),
551 static_cast<uint32_t
>(imageType.getSamplingInfo()),
552 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
553 static_cast<uint32_t
>(imageType.getImageFormat()));
557 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
558 typeEnum = spirv::Opcode::OpTypeArray;
559 uint32_t elementTypeID = 0;
560 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
561 serializationCtx))) {
564 operands.push_back(elementTypeID);
// The array length operand is the <id> of an integer constant, not a literal.
565 if (
auto elementCountID = prepareConstantInt(
567 operands.push_back(elementCountID);
569 return processTypeDecoration(loc, arrayType, resultID);
572 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
573 uint32_t pointeeTypeID = 0;
575 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
// The pointee is an identified struct whose identifier is still in
// serializationCtx, i.e. we are inside that struct's own serialization. This
// is a recursive type: forward-declare the pointer and defer its
// OpTypePointer via recursiveStructInfos.
578 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
584 forwardPtrOperands.push_back(resultID);
585 forwardPtrOperands.push_back(
586 static_cast<uint32_t
>(ptrType.getStorageClass()));
589 spirv::Opcode::OpTypeForwardPointer,
601 deferSerialization =
true;
605 recursiveStructInfos[structType].push_back(
606 {resultID, ptrType.getStorageClass()});
608 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
613 typeEnum = spirv::Opcode::OpTypePointer;
614 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
615 operands.push_back(pointeeTypeID);
// Structs reached through interface storage-class pointers receive the Block
// decoration.
617 if (isInterfaceStructPtrType(ptrType)) {
618 if (failed(emitDecoration(getTypeID(pointeeStruct),
619 spirv::Decoration::Block)))
620 return emitError(loc,
"cannot decorate ")
621 << pointeeStruct <<
" with Block decoration";
627 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
628 uint32_t elementTypeID = 0;
629 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
630 elementTypeID, serializationCtx))) {
633 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
634 operands.push_back(elementTypeID);
635 return processTypeDecoration(loc, runtimeArrayType, resultID);
638 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
639 typeEnum = spirv::Opcode::OpTypeSampledImage;
640 uint32_t imageTypeID = 0;
642 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
645 operands.push_back(imageTypeID);
649 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
650 if (structType.isIdentified()) {
651 if (failed(processName(resultID, structType.getIdentifier())))
// Mark this identified struct as in progress so pointers back to it are
// detected as recursive. The mark is removed once the struct is finished.
653 serializationCtx.insert(structType.getIdentifier());
656 bool hasOffset = structType.hasOffset();
657 for (
auto elementIndex :
658 llvm::seq<uint32_t>(0, structType.getNumElements())) {
659 uint32_t elementTypeID = 0;
660 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
661 elementTypeID, serializationCtx))) {
664 operands.push_back(elementTypeID);
// Member Offset decoration, presumably emitted only when `hasOffset` (the
// condition is on a line not shown).
668 elementIndex, 1, spirv::Decoration::Offset,
669 static_cast<uint32_t
>(structType.getMemberOffset(elementIndex))};
670 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
671 return emitError(loc,
"cannot decorate ")
672 << elementIndex <<
"-th member of " << structType
673 <<
" with its offset";
678 structType.getMemberDecorations(memberDecorations);
680 for (
auto &memberDecoration : memberDecorations) {
681 if (failed(processMemberDecoration(resultID, memberDecoration))) {
682 return emitError(loc,
"cannot decorate ")
683 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
684 <<
"-th member of " << structType <<
" with "
685 << stringifyDecoration(memberDecoration.
decoration);
689 typeEnum = spirv::Opcode::OpTypeStruct;
691 if (structType.isIdentified())
692 serializationCtx.remove(structType.getIdentifier());
697 if (
auto cooperativeMatrixType =
698 dyn_cast<spirv::CooperativeMatrixType>(type)) {
699 uint32_t elementTypeID = 0;
700 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
701 elementTypeID, serializationCtx))) {
704 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
// Scope, rows, columns and use are each passed as the <id> of an integer
// constant built via prepareConstantInt.
705 auto getConstantOp = [&](uint32_t id) {
707 return prepareConstantInt(loc, attr);
710 operands, elementTypeID,
711 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
712 getConstantOp(cooperativeMatrixType.getRows()),
713 getConstantOp(cooperativeMatrixType.getColumns()),
714 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
718 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
719 uint32_t elementTypeID = 0;
720 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
721 serializationCtx))) {
724 typeEnum = spirv::Opcode::OpTypeMatrix;
725 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
730 return emitError(loc,
"unhandled type in serialization: ") << type;
// Serializer::prepareFunctionType (excerpt): fills `operands` for an
// OpTypeFunction. The operands are the return type <id> (void when the
// function has no results), followed by one type <id> per input. At most one
// result is supported (asserted).
// NOTE(review): the loop variable `res` iterates over the *inputs*, not
// results. A name such as `inputType` would read more clearly.
734 Serializer::prepareFunctionType(
Location loc, FunctionType type,
735 spirv::Opcode &typeEnum,
737 typeEnum = spirv::Opcode::OpTypeFunction;
738 assert(type.getNumResults() <= 1 &&
739 "serialization supports only a single return value");
740 uint32_t resultID = 0;
741 if (failed(processType(
742 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
746 operands.push_back(resultID);
747 for (
auto &res : type.getInputs()) {
748 uint32_t argTypeID = 0;
749 if (failed(processType(loc, res, argTypeID))) {
752 operands.push_back(argTypeID);
// Serializer::prepareConstant (excerpt): returns the <id> of a constant of
// type `constType` with value `valueAttr`. Steps, in order:
//   1. Scalars are tried first via prepareConstantScalar.
//   2. An <id> already cached (getConstantID) is reused.
//   3. Dense-elements and array attributes are serialized as composites.
//   4. Any other attribute is an error.
// New <id>s are cached in constIDMap. The failure return value is on lines
// not shown.
// NOTE(review): DenseElementsAttr::getType() already returns a ShapedType, so
// the dyn_cast<ShapedType> below is redundant. Its result is also used
// without a null check; attr.getType().getRank() would be simpler.
761 uint32_t Serializer::prepareConstant(
Location loc,
Type constType,
763 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
770 if (
auto id = getConstantID(valueAttr)) {
775 if (failed(processType(loc, constType, typeID))) {
779 uint32_t resultID = 0;
780 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
781 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
783 resultID = prepareDenseElementsConstant(loc, constType, attr,
785 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
786 resultID = prepareArrayConstant(loc, constType, arrayAttr);
790 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
794 constIDMap[valueAttr] = resultID;
// Serializer::prepareArrayConstant (excerpt): serializes `attr` as an
// OpConstantComposite of array type `constType`. Each element constant is
// prepared recursively with the array's element type. The reserve adds 2 to
// the element count, presumably for the result type <id> and the result <id>.
798 uint32_t Serializer::prepareArrayConstant(
Location loc,
Type constType,
801 if (failed(processType(loc, constType, typeID))) {
805 uint32_t resultID = getNextID();
807 operands.reserve(attr.size() + 2);
808 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
810 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
811 operands.push_back(elementID);
816 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
// Serializer::prepareDenseElementsConstant (excerpt): recursively serializes
// `valueAttr` one dimension at a time.
//   - Innermost level (dim == rank): returns the scalar constant at `index`.
//     That is a bool for i1 integer elements, otherwise an int, or a float.
//   - Outer levels: builds an OpConstantComposite from the sub-constants
//     along dimension `dim`, typed with element 0 of the composite type.
// `index` identifies the element position. Its per-iteration update is on a
// line not shown.
// NOTE(review): as in prepareConstant, valueAttr.getType() is already a
// ShapedType, so the dyn_cast is redundant and its result is used unchecked.
825 Serializer::prepareDenseElementsConstant(
Location loc,
Type constType,
828 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
829 assert(dim <= shapedType.getRank());
830 if (shapedType.getRank() == dim) {
831 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
832 return attr.getType().getElementType().isInteger(1)
833 ? prepareConstantBool(loc, attr.getValues<
BoolAttr>()[index])
834 : prepareConstantInt(loc,
835 attr.getValues<IntegerAttr>()[index]);
837 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
838 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
844 if (failed(processType(loc, constType, typeID))) {
848 uint32_t resultID = getNextID();
850 operands.reserve(shapedType.getDimSize(dim) + 2);
851 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
852 for (
int i = 0; i < shapedType.getDimSize(dim); ++i) {
854 if (
auto elementID = prepareDenseElementsConstant(
855 loc, elementType, valueAttr, dim + 1, index)) {
856 operands.push_back(elementID);
861 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
// prepareConstantScalar (excerpt; presumably the helper prepareConstant calls,
// though its signature is not shown): dispatches a scalar attribute to the
// matching helper, forwarding `isSpec`. The helpers are for FloatAttr,
// BoolAttr and IntegerAttr.
// BoolAttr must be tested before IntegerAttr because a BoolAttr is itself an
// IntegerAttr (a signless i1).
869 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
870 return prepareConstantFp(loc, floatAttr, isSpec);
872 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
873 return prepareConstantBool(loc, boolAttr, isSpec);
875 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
876 return prepareConstantInt(loc, intAttr, isSpec);
// prepareConstantBool (excerpt; the signature is not shown): returns the <id>
// of a boolean constant.
//   - Reuses a cached <id> from getConstantID.
//   - Otherwise emits OpConstantTrue/OpConstantFalse, or the OpSpecConstant*
//     forms when `isSpec`.
//   - Caches the new <id> in constIDMap.
// The exact conditions around the cache lookup and update are partly on lines
// not shown. The type is obtained through the BoolAttr's IntegerAttr view.
886 if (
auto id = getConstantID(boolAttr)) {
893 if (failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
897 auto resultID = getNextID();
899 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
900 : spirv::Opcode::OpConstantTrue)
901 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
902 : spirv::Opcode::OpConstantFalse);
906 constIDMap[boolAttr] = resultID;
// Serializer::prepareConstantInt (excerpt): returns the <id> of an integer
// constant, emitted as OpConstant (or OpSpecConstant when `isSpec`). The
// width checks are on lines not shown. From the visible code there are three
// paths:
//   - single word: sign-extended for signed types, zero-extended otherwise;
//   - two words: split via bit_cast to a DoubleWord;
//   - other widths: error "cannot serialize <bitwidth>-bit integer literal".
// New <id>s are cached in constIDMap.
// NOTE(review): SPIR-V requires multi-word literals low-order word first.
// Bit-casting a 64-bit value to a two-word struct yields that order only on
// little-endian hosts. Confirm big-endian hosts are unsupported, or split the
// words explicitly.
911 uint32_t Serializer::prepareConstantInt(
Location loc, IntegerAttr intAttr,
915 if (
auto id = getConstantID(intAttr)) {
922 if (failed(processType(loc, intAttr.getType(), typeID))) {
926 auto resultID = getNextID();
927 APInt value = intAttr.getValue();
928 unsigned bitwidth = value.getBitWidth();
929 bool isSigned = intAttr.getType().isSignedInteger();
931 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
944 word =
static_cast<int32_t
>(value.getSExtValue());
946 word =
static_cast<uint32_t
>(value.getZExtValue());
958 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
960 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
963 {typeID, resultID, words.word1, words.word2});
966 std::string valueStr;
967 llvm::raw_string_ostream rss(valueStr);
968 value.print(rss,
false);
971 << bitwidth <<
"-bit integer literal: " << valueStr;
977 constIDMap[intAttr] = resultID;
// Serializer::prepareConstantFp (excerpt): returns the <id> of a float
// constant, emitted as OpConstant (or OpSpecConstant when `isSpec`). Encoding
// by floating-point semantics:
//   - IEEE single: one word, the raw float bits;
//   - IEEE double: two words via bit_cast to a DoubleWord;
//   - IEEE half: one word, the zero-extended raw bits;
//   - anything else: error "cannot serialize <type>-typed float literal".
// New <id>s are cached in constIDMap.
// NOTE(review): the double path has the same host-endianness dependence as
// prepareConstantInt (SPIR-V wants the low-order word first). `intValue` is
// not used on any visible line, and the f16 path recomputes
// bitcastToAPInt(). If it is unused elsewhere too, reuse it or drop it.
982 uint32_t Serializer::prepareConstantFp(
Location loc, FloatAttr floatAttr,
986 if (
auto id = getConstantID(floatAttr)) {
993 if (failed(processType(loc, floatAttr.getType(), typeID))) {
997 auto resultID = getNextID();
998 APFloat value = floatAttr.getValue();
999 APInt intValue = value.bitcastToAPInt();
1002 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1004 if (&value.getSemantics() == &APFloat::IEEEsingle()) {
1005 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1007 }
else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
1011 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1013 {typeID, resultID, words.word1, words.word2});
1014 }
else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
1016 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
1019 std::string valueStr;
1020 llvm::raw_string_ostream rss(valueStr);
1024 << floatAttr.getType() <<
"-typed float literal: " << valueStr;
1029 constIDMap[floatAttr] = resultID;
// Returns the label <id> for `block`. An existing <id> from getBlockID is
// reused; the early return is on a line not shown. Otherwise a fresh <id> is
// allocated and recorded in blockIDMap.
1038 uint32_t Serializer::getOrCreateBlockID(
Block *block) {
1039 if (uint32_t
id = getBlockID(block))
1041 return blockIDMap[block] = getNextID();
// Serializer::printBlock (excerpt): for debugging, prints
// "block <ptr> (id = " followed by the block's <id> when it has one. The rest
// of the output is on lines not shown.
1045 void Serializer::printBlock(
Block *block, raw_ostream &os) {
1046 os <<
"block " << block <<
" (id = ";
1047 if (uint32_t
id = getBlockID(block))
// Serializer::processBlock (excerpt): serializes `block`. Visible steps:
//   1. Get the block's label <id>. Label emission (see `omitLabel`) is on
//      lines not shown.
//   2. Emit OpPhi instructions for the block's arguments.
//   3. If the block contains a spirv::LoopOp or spirv::SelectionOp (condition
//      partly not shown), invoke the pending `emitMerge` callback, clear it,
//      and allocate a fresh <id>, presumably for a new label.
//   4. Process every op except the terminator.
//   5. Invoke `emitMerge` again (guard not shown).
//   6. Process the terminator.
1056 Serializer::processBlock(
Block *block,
bool omitLabel,
1058 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1059 LLVM_DEBUG(block->
print(llvm::dbgs()));
1060 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1062 uint32_t blockID = getOrCreateBlockID(block);
1063 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1070 if (failed(emitPhiForBlockArguments(block)))
1080 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1081 if (failed(emitMerge()))
1083 emitMerge =
nullptr;
1086 uint32_t blockID = getNextID();
1092 for (
Operation &op : llvm::drop_end(*block)) {
1093 if (failed(processOperation(&op)))
1099 if (failed(emitMerge()))
1101 if (failed(processOperation(&block->
back())))
// Serializer::emitPhiForBlockArguments (excerpt): emits one OpPhi per argument
// of `block`.
// Step 1, collect predecessors. For each MLIR predecessor, record the SPIR-V
// predecessor block (computed on lines not shown) together with the operands
// its terminator passes to `block`. Only spirv::BranchOp and
// spirv::BranchConditionalOp terminators are supported; others are an error.
// Step 2, emit the phis. Each OpPhi gets a type <id> and a fresh result <id>,
// then one (value <id>, predecessor label <id>) pair per predecessor.
// A value that has no <id> yet is recorded in deferredPhiValues together with
// a position in functionBody, so it can be patched later. Finally the block
// argument is mapped to its phi <id> in valueIDMap.
1107 LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1113 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1122 auto *terminator = mlirPredecessor->getTerminator();
1123 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1124 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1125 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1134 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1135 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1136 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1137 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1138 }
else if (
auto branchCondOp =
1139 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
// A conditional branch reaches `block` through either its true edge or its
// false edge. Use the operands of whichever edge targets `block`.
1140 std::optional<OperandRange> blockOperands;
1141 if (branchCondOp.getTrueTarget() == block) {
1142 blockOperands = branchCondOp.getTrueTargetOperands();
1144 assert(branchCondOp.getFalseTarget() == block);
1145 blockOperands = branchCondOp.getFalseTargetOperands();
1148 assert(!blockOperands->empty() &&
1149 "expected non-empty block operand range");
1150 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1152 return terminator->emitError(
"unimplemented terminator for Phi creation");
1155 llvm::dbgs() <<
" block arguments:\n";
1156 for (
Value v : predecessors.back().second)
1157 llvm::dbgs() <<
" " << v <<
"\n";
1162 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1166 uint32_t phiTypeID = 0;
1167 if (failed(processType(arg.
getLoc(), arg.
getType(), phiTypeID)))
1169 uint32_t phiID = getNextID();
1171 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1172 << arg <<
" (id = " << phiID <<
")\n");
1176 phiArgs.push_back(phiTypeID);
1177 phiArgs.push_back(phiID);
1179 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1180 Value value = predecessors[predIndex].second[argIndex];
1181 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1182 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1183 <<
") value " << value <<
' ');
1185 uint32_t valueId = getValueID(value);
// The value is not serialized yet. Remember where its <id> goes so it can be
// patched once known.
1189 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1190 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1193 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1195 phiArgs.push_back(valueId);
1197 phiArgs.push_back(predBlockId);
1201 valueIDMap[arg] = phiID;
// Serializer::encodeExtensionInstruction (excerpt): encodes `op` as an
// extended instruction from the set `extensionSetName`.
// The set's import <id> lives in extendedInstSetIDMap. A fresh one is
// allocated, presumably on first use (the guard is not shown), and an import
// operand list is built for it.
// `operands` must start with the result type and result <id>, hence the
// size check. The extended-instruction operands are, in order:
//   - those first two words;
//   - the set <id>;
//   - `extensionOpcode`;
//   - the remaining operands.
1211 LogicalResult Serializer::encodeExtensionInstruction(
1212 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1215 auto &setID = extendedInstSetIDMap[extensionSetName];
1217 setID = getNextID();
1219 importOperands.push_back(setID);
1227 if (operands.size() < 2) {
1228 return op->
emitError(
"extended instructions must have a result encoding");
1231 extInstOperands.reserve(operands.size() + 2);
1232 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1233 extInstOperands.push_back(setID);
1234 extInstOperands.push_back(extensionOpcode);
1235 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
// Serializer::processOperation (excerpt): dispatches `opInst` to a dedicated
// handler based on its op class, using the .Case chain of a type switch. The
// switch's construction is on lines not shown. Ops without a dedicated case
// reach the catch-all at the end, which defers to
// dispatchToAutogenSerialization.
1241 LogicalResult Serializer::processOperation(
Operation *opInst) {
1242 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1247 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1248 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1249 .Case([&](spirv::BranchConditionalOp op) {
1250 return processBranchConditionalOp(op);
1252 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1253 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1254 .Case([&](spirv::GlobalVariableOp op) {
1255 return processGlobalVariableOp(op);
1257 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1258 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1259 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1260 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1261 .Case([&](spirv::SpecConstantCompositeOp op) {
1262 return processSpecConstantCompositeOp(op);
1264 .Case([&](spirv::SpecConstantOperationOp op) {
1265 return processSpecConstantOperationOp(op);
1267 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1268 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1273 [&](
Operation *op) {
return dispatchToAutogenSerialization(op); });
// Serializer::processOpWithoutGrammarAttr (excerpt): generic serialization of
// `op`.
// Operands are built in this order:
//   - optionally a result type <id>;
//   - a fresh result <id>, recorded in valueIDMap for result 0;
//   - the <id> of each operand.
// It then emits a debug line into functionBody, followed by either a plain
// instruction (when `extInstSet` is empty) or an extended instruction via
// encodeExtensionInstruction. Finally it runs processDecoration for each
// attribute (the loop is not shown) on the result <id>.
1276 LogicalResult Serializer::processOpWithoutGrammarAttr(
Operation *op,
1277 StringRef extInstSet,
1282 uint32_t resultID = 0;
1284 uint32_t resultTypeID = 0;
1287 operands.push_back(resultTypeID);
1289 resultID = getNextID();
1290 operands.push_back(resultID);
1291 valueIDMap[op->
getResult(0)] = resultID;
1295 operands.push_back(getValueID(operand));
1297 if (failed(emitDebugLine(functionBody, loc)))
1300 if (extInstSet.empty()) {
1304 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1310 if (failed(processDecoration(loc, resultID, attr)))
// Serializer::emitDecoration (excerpt): appends a decoration instruction to
// the `decorations` section. The instruction is presumably OpDecorate; the
// opcode line is not shown.
// Word count is 3 + params.size(). The 3 fixed words are presumably the
// opcode word, `target`, and the decoration; only the decoration word is
// visible here. The extra `params` words follow.
1318 LogicalResult Serializer::emitDecoration(uint32_t target,
1319 spirv::Decoration decoration,
1321 uint32_t wordCount = 3 + params.size();
1322 llvm::append_values(
1325 static_cast<uint32_t
>(decoration));
1326 llvm::append_range(decorations, params);
// Debug-line emission (excerpt; presumably Serializer::emitDebugLine, which
// processOpWithoutGrammarAttr calls, though the signature is not shown).
//   - Right after a merge instruction, the lastProcessedWasMergeInst flag is
//     cleared; the rest of that branch is not shown.
//   - Otherwise, for a FileLineColLoc, it encodes fileID, line and column.
// NOTE(review): fileLoc is dereferenced on the emission line. Confirm that a
// null check for non-FileLineColLoc locations exists on the lines not shown.
1335 if (lastProcessedWasMergeInst) {
1336 lastProcessedWasMergeInst =
false;
1340 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1343 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
An attribute that represents a reference to a dense vector or tensor object.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.