20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
30 #define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
80 uint32_t wordCount = 1 + operands.size();
82 binary.append(operands.begin(), operands.end());
90 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
92 if (failed(module.verifyInvariants()))
103 for (
auto &op : *module.getBody()) {
104 if (failed(processOperation(&op))) {
109 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
115 extensions.size() + extendedSets.size() +
116 memoryModel.size() + entryPoints.size() +
117 executionModes.size() + decorations.size() +
118 typesGlobalValues.size() + functions.size();
121 binary.reserve(moduleSize);
125 binary.append(capabilities.begin(), capabilities.end());
126 binary.append(extensions.begin(), extensions.end());
127 binary.append(extendedSets.begin(), extendedSets.end());
128 binary.append(memoryModel.begin(), memoryModel.end());
129 binary.append(entryPoints.begin(), entryPoints.end());
130 binary.append(executionModes.begin(), executionModes.end());
131 binary.append(debug.begin(), debug.end());
132 binary.append(names.begin(), names.end());
133 binary.append(decorations.begin(), decorations.end());
134 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
135 binary.append(functions.begin(), functions.end());
140 os <<
"\n= Value <id> Map =\n\n";
141 for (
auto valueIDPair : valueIDMap) {
142 Value val = valueIDPair.first;
143 os <<
" " << val <<
" "
144 <<
"id = " << valueIDPair.second <<
' ';
146 os <<
"from op '" << op->getName() <<
"'";
147 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
148 Block *block = arg.getOwner();
149 os <<
"from argument of block " << block <<
' ';
161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162 auto funcID = funcIDMap.lookup(fnName);
164 funcID = getNextID();
165 funcIDMap[fnName] = funcID;
170 void Serializer::processCapability() {
171 for (
auto cap : module.getVceTriple()->getCapabilities())
173 {
static_cast<uint32_t
>(cap)});
176 void Serializer::processDebugInfo() {
179 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
181 fileID = getNextID();
183 operands.push_back(fileID);
189 void Serializer::processExtension() {
191 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
198 void Serializer::processMemoryModel() {
199 StringAttr memoryModelName = module.getMemoryModelAttrName();
200 auto mm =
static_cast<uint32_t
>(
201 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
204 StringAttr addressingModelName = module.getAddressingModelAttrName();
205 auto am =
static_cast<uint32_t
>(
206 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
215 if (attrName ==
"fp_fast_math_mode")
216 return "FPFastMathMode";
218 if (attrName ==
"fp_rounding_mode")
219 return "FPRoundingMode";
221 if (attrName ==
"cache_control_load_intel")
222 return "CacheControlLoadINTEL";
223 if (attrName ==
"cache_control_store_intel")
224 return "CacheControlStoreINTEL";
226 return llvm::convertToCamelFromSnakeCase(attrName,
true);
229 template <
typename AttrTy,
typename EmitF>
233 auto arrayAttr = dyn_cast<ArrayAttr>(attrList);
235 return emitError(loc,
"expecting array attribute of ")
236 << attrName <<
" for " << stringifyDecoration(decoration);
238 if (arrayAttr.empty()) {
239 return emitError(loc,
"expecting non-empty array attribute of ")
240 << attrName <<
" for " << stringifyDecoration(decoration);
242 for (
Attribute attr : arrayAttr.getValue()) {
243 auto cacheControlAttr = dyn_cast<AttrTy>(attr);
244 if (!cacheControlAttr) {
245 return emitError(loc,
"expecting array attribute of ")
246 << attrName <<
" for " << stringifyDecoration(decoration);
250 if (failed(emitter(cacheControlAttr)))
256 LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
257 Decoration decoration,
260 switch (decoration) {
261 case spirv::Decoration::LinkageAttributes: {
264 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
265 auto linkageName = linkageAttr.getLinkageName();
266 auto linkageType = linkageAttr.getLinkageType().getValue();
270 args.push_back(
static_cast<uint32_t
>(linkageType));
273 case spirv::Decoration::FPFastMathMode:
274 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
275 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
278 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
279 << stringifyDecoration(decoration);
280 case spirv::Decoration::FPRoundingMode:
281 if (
auto intAttr = dyn_cast<FPRoundingModeAttr>(attr)) {
282 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
285 return emitError(loc,
"expected FPRoundingModeAttr attribute for ")
286 << stringifyDecoration(decoration);
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::DescriptorSet:
289 case spirv::Decoration::Location:
290 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
291 args.push_back(intAttr.getValue().getZExtValue());
294 return emitError(loc,
"expected integer attribute for ")
295 << stringifyDecoration(decoration);
296 case spirv::Decoration::BuiltIn:
297 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
298 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
300 args.push_back(
static_cast<uint32_t
>(*enumVal));
304 << stringifyDecoration(decoration) <<
" decoration attribute "
305 << strAttr.getValue();
307 return emitError(loc,
"expected string attribute for ")
308 << stringifyDecoration(decoration);
309 case spirv::Decoration::Aliased:
310 case spirv::Decoration::AliasedPointer:
311 case spirv::Decoration::Flat:
312 case spirv::Decoration::NonReadable:
313 case spirv::Decoration::NonWritable:
314 case spirv::Decoration::NoPerspective:
315 case spirv::Decoration::NoSignedWrap:
316 case spirv::Decoration::NoUnsignedWrap:
317 case spirv::Decoration::RelaxedPrecision:
318 case spirv::Decoration::Restrict:
319 case spirv::Decoration::RestrictPointer:
320 case spirv::Decoration::NoContraction:
321 case spirv::Decoration::Constant:
324 if (isa<UnitAttr, DecorationAttr>(attr))
327 "expected unit attribute or decoration attribute for ")
328 << stringifyDecoration(decoration);
329 case spirv::Decoration::CacheControlLoadINTEL:
330 return processDecorationList<CacheControlLoadINTELAttr>(
331 loc, decoration, attr,
"CacheControlLoadINTEL",
332 [&](CacheControlLoadINTELAttr attr) {
333 unsigned cacheLevel = attr.getCacheLevel();
334 LoadCacheControl loadCacheControl = attr.getLoadCacheControl();
335 return emitDecoration(
336 resultID, decoration,
337 {cacheLevel,
static_cast<uint32_t
>(loadCacheControl)});
339 case spirv::Decoration::CacheControlStoreINTEL:
340 return processDecorationList<CacheControlStoreINTELAttr>(
341 loc, decoration, attr,
"CacheControlStoreINTEL",
342 [&](CacheControlStoreINTELAttr attr) {
343 unsigned cacheLevel = attr.getCacheLevel();
344 StoreCacheControl storeCacheControl = attr.getStoreCacheControl();
345 return emitDecoration(
346 resultID, decoration,
347 {cacheLevel,
static_cast<uint32_t
>(storeCacheControl)});
350 return emitError(loc,
"unhandled decoration ")
351 << stringifyDecoration(decoration);
353 return emitDecoration(resultID, decoration, args);
356 LogicalResult Serializer::processDecoration(
Location loc, uint32_t resultID,
358 StringRef attrName = attr.
getName().strref();
360 std::optional<Decoration> decoration =
361 spirv::symbolizeDecoration(decorationName);
364 loc,
"non-argument attributes expected to have snake-case-ified "
365 "decoration name, unhandled attribute with name : ")
368 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
371 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
372 assert(!name.empty() &&
"unexpected empty string for OpName");
377 nameOperands.push_back(resultID);
384 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
388 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
394 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
398 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
403 LogicalResult Serializer::processMemberDecoration(
408 static_cast<uint32_t
>(memberDecoration.
decoration)});
423 bool Serializer::isInterfaceStructPtrType(
Type type)
const {
424 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
425 switch (ptrType.getStorageClass()) {
426 case spirv::StorageClass::PhysicalStorageBuffer:
427 case spirv::StorageClass::PushConstant:
428 case spirv::StorageClass::StorageBuffer:
429 case spirv::StorageClass::Uniform:
430 return isa<spirv::StructType>(ptrType.getPointeeType());
438 LogicalResult Serializer::processType(
Location loc,
Type type,
443 return processTypeImpl(loc, type, typeID, serializationCtx);
447 Serializer::processTypeImpl(
Location loc,
Type type, uint32_t &typeID,
449 typeID = getTypeID(type);
453 typeID = getNextID();
456 operands.push_back(typeID);
457 auto typeEnum = spirv::Opcode::OpTypeVoid;
458 bool deferSerialization =
false;
460 if ((isa<FunctionType>(type) &&
461 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
463 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
464 deferSerialization, serializationCtx))) {
465 if (deferSerialization)
468 typeIDMap[type] = typeID;
472 if (recursiveStructInfos.count(type) != 0) {
475 for (
auto &ptrInfo : recursiveStructInfos[type]) {
479 ptrOperands.push_back(ptrInfo.pointerTypeID);
480 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
481 ptrOperands.push_back(typeIDMap[type]);
487 recursiveStructInfos[type].clear();
496 LogicalResult Serializer::prepareBasicType(
497 Location loc,
Type type, uint32_t resultID, spirv::Opcode &typeEnum,
500 deferSerialization =
false;
502 if (isVoidType(type)) {
503 typeEnum = spirv::Opcode::OpTypeVoid;
507 if (
auto intType = dyn_cast<IntegerType>(type)) {
508 if (intType.getWidth() == 1) {
509 typeEnum = spirv::Opcode::OpTypeBool;
513 typeEnum = spirv::Opcode::OpTypeInt;
514 operands.push_back(intType.getWidth());
519 operands.push_back(intType.isSigned() ? 1 : 0);
523 if (
auto floatType = dyn_cast<FloatType>(type)) {
524 typeEnum = spirv::Opcode::OpTypeFloat;
525 operands.push_back(floatType.getWidth());
526 if (floatType.isBF16()) {
527 operands.push_back(
static_cast<uint32_t
>(spirv::FPEncoding::BFloat16KHR));
532 if (
auto vectorType = dyn_cast<VectorType>(type)) {
533 uint32_t elementTypeID = 0;
534 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
535 serializationCtx))) {
538 typeEnum = spirv::Opcode::OpTypeVector;
539 operands.push_back(elementTypeID);
540 operands.push_back(vectorType.getNumElements());
544 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
545 typeEnum = spirv::Opcode::OpTypeImage;
546 uint32_t sampledTypeID = 0;
547 if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
550 llvm::append_values(operands, sampledTypeID,
551 static_cast<uint32_t
>(imageType.getDim()),
552 static_cast<uint32_t
>(imageType.getDepthInfo()),
553 static_cast<uint32_t
>(imageType.getArrayedInfo()),
554 static_cast<uint32_t
>(imageType.getSamplingInfo()),
555 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
556 static_cast<uint32_t
>(imageType.getImageFormat()));
560 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
561 typeEnum = spirv::Opcode::OpTypeArray;
562 uint32_t elementTypeID = 0;
563 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
564 serializationCtx))) {
567 operands.push_back(elementTypeID);
568 if (
auto elementCountID = prepareConstantInt(
570 operands.push_back(elementCountID);
572 return processTypeDecoration(loc, arrayType, resultID);
575 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
576 uint32_t pointeeTypeID = 0;
578 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
581 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
587 forwardPtrOperands.push_back(resultID);
588 forwardPtrOperands.push_back(
589 static_cast<uint32_t
>(ptrType.getStorageClass()));
592 spirv::Opcode::OpTypeForwardPointer,
604 deferSerialization =
true;
608 recursiveStructInfos[structType].push_back(
609 {resultID, ptrType.getStorageClass()});
611 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
616 typeEnum = spirv::Opcode::OpTypePointer;
617 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
618 operands.push_back(pointeeTypeID);
620 if (isInterfaceStructPtrType(ptrType)) {
621 if (failed(emitDecoration(getTypeID(pointeeStruct),
622 spirv::Decoration::Block)))
623 return emitError(loc,
"cannot decorate ")
624 << pointeeStruct <<
" with Block decoration";
630 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
631 uint32_t elementTypeID = 0;
632 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
633 elementTypeID, serializationCtx))) {
636 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
637 operands.push_back(elementTypeID);
638 return processTypeDecoration(loc, runtimeArrayType, resultID);
641 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
642 typeEnum = spirv::Opcode::OpTypeSampledImage;
643 uint32_t imageTypeID = 0;
645 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
648 operands.push_back(imageTypeID);
652 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
653 if (structType.isIdentified()) {
654 if (failed(processName(resultID, structType.getIdentifier())))
656 serializationCtx.insert(structType.getIdentifier());
659 bool hasOffset = structType.hasOffset();
660 for (
auto elementIndex :
661 llvm::seq<uint32_t>(0, structType.getNumElements())) {
662 uint32_t elementTypeID = 0;
663 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
664 elementTypeID, serializationCtx))) {
667 operands.push_back(elementTypeID);
671 elementIndex, 1, spirv::Decoration::Offset,
672 static_cast<uint32_t
>(structType.getMemberOffset(elementIndex))};
673 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
674 return emitError(loc,
"cannot decorate ")
675 << elementIndex <<
"-th member of " << structType
676 <<
" with its offset";
681 structType.getMemberDecorations(memberDecorations);
683 for (
auto &memberDecoration : memberDecorations) {
684 if (failed(processMemberDecoration(resultID, memberDecoration))) {
685 return emitError(loc,
"cannot decorate ")
686 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
687 <<
"-th member of " << structType <<
" with "
688 << stringifyDecoration(memberDecoration.
decoration);
692 typeEnum = spirv::Opcode::OpTypeStruct;
694 if (structType.isIdentified())
695 serializationCtx.remove(structType.getIdentifier());
700 if (
auto cooperativeMatrixType =
701 dyn_cast<spirv::CooperativeMatrixType>(type)) {
702 uint32_t elementTypeID = 0;
703 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
704 elementTypeID, serializationCtx))) {
707 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
708 auto getConstantOp = [&](uint32_t id) {
710 return prepareConstantInt(loc, attr);
713 operands, elementTypeID,
714 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
715 getConstantOp(cooperativeMatrixType.getRows()),
716 getConstantOp(cooperativeMatrixType.getColumns()),
717 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
721 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
722 uint32_t elementTypeID = 0;
723 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
724 serializationCtx))) {
727 typeEnum = spirv::Opcode::OpTypeMatrix;
728 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
733 return emitError(loc,
"unhandled type in serialization: ") << type;
737 Serializer::prepareFunctionType(
Location loc, FunctionType type,
738 spirv::Opcode &typeEnum,
740 typeEnum = spirv::Opcode::OpTypeFunction;
741 assert(type.getNumResults() <= 1 &&
742 "serialization supports only a single return value");
743 uint32_t resultID = 0;
744 if (failed(processType(
745 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
749 operands.push_back(resultID);
750 for (
auto &res : type.getInputs()) {
751 uint32_t argTypeID = 0;
752 if (failed(processType(loc, res, argTypeID))) {
755 operands.push_back(argTypeID);
764 uint32_t Serializer::prepareConstant(
Location loc,
Type constType,
766 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
773 if (
auto id = getConstantID(valueAttr)) {
778 if (failed(processType(loc, constType, typeID))) {
782 uint32_t resultID = 0;
783 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
784 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
786 resultID = prepareDenseElementsConstant(loc, constType, attr,
788 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
789 resultID = prepareArrayConstant(loc, constType, arrayAttr);
793 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
797 constIDMap[valueAttr] = resultID;
801 uint32_t Serializer::prepareArrayConstant(
Location loc,
Type constType,
804 if (failed(processType(loc, constType, typeID))) {
808 uint32_t resultID = getNextID();
810 operands.reserve(attr.size() + 2);
811 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
813 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
814 operands.push_back(elementID);
819 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
828 Serializer::prepareDenseElementsConstant(
Location loc,
Type constType,
831 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
832 assert(dim <= shapedType.getRank());
833 if (shapedType.getRank() == dim) {
834 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
835 return attr.getType().getElementType().isInteger(1)
836 ? prepareConstantBool(loc, attr.getValues<
BoolAttr>()[index])
837 : prepareConstantInt(loc,
838 attr.getValues<IntegerAttr>()[index]);
840 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
841 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
847 if (failed(processType(loc, constType, typeID))) {
851 int64_t numberOfConstituents = shapedType.getDimSize(dim);
852 uint32_t resultID = getNextID();
854 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
860 if (isa<spirv::CooperativeMatrixType>(constType)) {
864 "cannot serialize a non-splat value for a cooperative matrix type");
872 if (
auto elementID = prepareDenseElementsConstant(
873 loc, elementType, valueAttr, shapedType.getRank(), index)) {
874 operands.push_back(elementID);
879 operands.reserve(numberOfConstituents + 2);
880 for (
int i = 0; i < numberOfConstituents; ++i) {
882 if (
auto elementID = prepareDenseElementsConstant(
883 loc, elementType, valueAttr, dim + 1, index)) {
884 operands.push_back(elementID);
890 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
898 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
899 return prepareConstantFp(loc, floatAttr, isSpec);
901 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
902 return prepareConstantBool(loc, boolAttr, isSpec);
904 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
905 return prepareConstantInt(loc, intAttr, isSpec);
915 if (
auto id = getConstantID(boolAttr)) {
922 if (failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
926 auto resultID = getNextID();
928 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
929 : spirv::Opcode::OpConstantTrue)
930 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
931 : spirv::Opcode::OpConstantFalse);
935 constIDMap[boolAttr] = resultID;
940 uint32_t Serializer::prepareConstantInt(
Location loc, IntegerAttr intAttr,
944 if (
auto id = getConstantID(intAttr)) {
951 if (failed(processType(loc, intAttr.getType(), typeID))) {
955 auto resultID = getNextID();
956 APInt value = intAttr.getValue();
957 unsigned bitwidth = value.getBitWidth();
958 bool isSigned = intAttr.getType().isSignedInteger();
960 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
973 word =
static_cast<int32_t
>(value.getSExtValue());
975 word =
static_cast<uint32_t
>(value.getZExtValue());
987 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
989 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
992 {typeID, resultID, words.word1, words.word2});
995 std::string valueStr;
996 llvm::raw_string_ostream rss(valueStr);
997 value.print(rss,
false);
1000 << bitwidth <<
"-bit integer literal: " << valueStr;
1006 constIDMap[intAttr] = resultID;
1011 uint32_t Serializer::prepareConstantFp(
Location loc, FloatAttr floatAttr,
1015 if (
auto id = getConstantID(floatAttr)) {
1021 uint32_t typeID = 0;
1022 if (failed(processType(loc, floatAttr.getType(), typeID))) {
1026 auto resultID = getNextID();
1027 APFloat value = floatAttr.getValue();
1028 const llvm::fltSemantics *semantics = &value.getSemantics();
1031 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
1033 if (semantics == &APFloat::IEEEsingle()) {
1034 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
1036 }
else if (semantics == &APFloat::IEEEdouble()) {
1040 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
1042 {typeID, resultID, words.word1, words.word2});
1043 }
else if (semantics == &APFloat::IEEEhalf() ||
1044 semantics == &APFloat::BFloat()) {
1046 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
1049 std::string valueStr;
1050 llvm::raw_string_ostream rss(valueStr);
1054 << floatAttr.getType() <<
"-typed float literal: " << valueStr;
1059 constIDMap[floatAttr] = resultID;
1068 uint32_t Serializer::getOrCreateBlockID(
Block *block) {
1069 if (uint32_t
id = getBlockID(block))
1071 return blockIDMap[block] = getNextID();
1075 void Serializer::printBlock(
Block *block, raw_ostream &os) {
1076 os <<
"block " << block <<
" (id = ";
1077 if (uint32_t
id = getBlockID(block))
1086 Serializer::processBlock(
Block *block,
bool omitLabel,
1088 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1089 LLVM_DEBUG(block->
print(llvm::dbgs()));
1090 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1092 uint32_t blockID = getOrCreateBlockID(block);
1093 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1100 if (failed(emitPhiForBlockArguments(block)))
1110 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1111 if (failed(emitMerge()))
1113 emitMerge =
nullptr;
1116 uint32_t blockID = getNextID();
1122 for (
Operation &op : llvm::drop_end(*block)) {
1123 if (failed(processOperation(&op)))
1129 if (failed(emitMerge()))
1131 if (failed(processOperation(&block->
back())))
1137 LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1143 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1152 auto *terminator = mlirPredecessor->getTerminator();
1153 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1154 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1155 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1164 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1165 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1166 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1167 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1168 }
else if (
auto branchCondOp =
1169 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1170 std::optional<OperandRange> blockOperands;
1171 if (branchCondOp.getTrueTarget() == block) {
1172 blockOperands = branchCondOp.getTrueTargetOperands();
1174 assert(branchCondOp.getFalseTarget() == block);
1175 blockOperands = branchCondOp.getFalseTargetOperands();
1178 assert(!blockOperands->empty() &&
1179 "expected non-empty block operand range");
1180 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1182 return terminator->emitError(
"unimplemented terminator for Phi creation");
1185 llvm::dbgs() <<
" block arguments:\n";
1186 for (
Value v : predecessors.back().second)
1187 llvm::dbgs() <<
" " << v <<
"\n";
1192 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1196 uint32_t phiTypeID = 0;
1197 if (failed(processType(arg.
getLoc(), arg.
getType(), phiTypeID)))
1199 uint32_t phiID = getNextID();
1201 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1202 << arg <<
" (id = " << phiID <<
")\n");
1206 phiArgs.push_back(phiTypeID);
1207 phiArgs.push_back(phiID);
1209 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1210 Value value = predecessors[predIndex].second[argIndex];
1211 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1212 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1213 <<
") value " << value <<
' ');
1215 uint32_t valueId = getValueID(value);
1219 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1220 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1223 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1225 phiArgs.push_back(valueId);
1227 phiArgs.push_back(predBlockId);
1231 valueIDMap[arg] = phiID;
1241 LogicalResult Serializer::encodeExtensionInstruction(
1242 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1245 auto &setID = extendedInstSetIDMap[extensionSetName];
1247 setID = getNextID();
1249 importOperands.push_back(setID);
1257 if (operands.size() < 2) {
1258 return op->
emitError(
"extended instructions must have a result encoding");
1261 extInstOperands.reserve(operands.size() + 2);
1262 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1263 extInstOperands.push_back(setID);
1264 extInstOperands.push_back(extensionOpcode);
1265 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1271 LogicalResult Serializer::processOperation(
Operation *opInst) {
1272 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1277 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1278 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1279 .Case([&](spirv::BranchConditionalOp op) {
1280 return processBranchConditionalOp(op);
1282 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1283 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1284 .Case([&](spirv::GlobalVariableOp op) {
1285 return processGlobalVariableOp(op);
1287 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1288 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1289 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1290 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1291 .Case([&](spirv::SpecConstantCompositeOp op) {
1292 return processSpecConstantCompositeOp(op);
1294 .Case([&](spirv::SpecConstantOperationOp op) {
1295 return processSpecConstantOperationOp(op);
1297 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1298 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1303 [&](
Operation *op) {
return dispatchToAutogenSerialization(op); });
1306 LogicalResult Serializer::processOpWithoutGrammarAttr(
Operation *op,
1307 StringRef extInstSet,
1312 uint32_t resultID = 0;
1314 uint32_t resultTypeID = 0;
1317 operands.push_back(resultTypeID);
1319 resultID = getNextID();
1320 operands.push_back(resultID);
1321 valueIDMap[op->
getResult(0)] = resultID;
1325 operands.push_back(getValueID(operand));
1327 if (failed(emitDebugLine(functionBody, loc)))
1330 if (extInstSet.empty()) {
1334 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1340 if (failed(processDecoration(loc, resultID, attr)))
1348 LogicalResult Serializer::emitDecoration(uint32_t target,
1349 spirv::Decoration decoration,
1351 uint32_t wordCount = 3 + params.size();
1352 llvm::append_values(
1355 static_cast<uint32_t
>(decoration));
1356 llvm::append_range(decorations, params);
1365 if (lastProcessedWasMergeInst) {
1366 lastProcessedWasMergeInst =
false;
1370 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1373 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
An attribute that represents a reference to a dense vector or tensor object.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
LogicalResult processDecorationList(Location loc, Decoration decoration, Attribute attrList, StringRef attrName, EmitF emitter)
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.