20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
30 #define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
80 uint32_t wordCount = 1 + operands.size();
82 binary.append(operands.begin(), operands.end());
90 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
92 if (failed(module.verifyInvariants()))
103 for (
auto &op : *module.getBody()) {
104 if (failed(processOperation(&op))) {
109 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
115 extensions.size() + extendedSets.size() +
116 memoryModel.size() + entryPoints.size() +
117 executionModes.size() + decorations.size() +
118 typesGlobalValues.size() + functions.size();
121 binary.reserve(moduleSize);
125 binary.append(capabilities.begin(), capabilities.end());
126 binary.append(extensions.begin(), extensions.end());
127 binary.append(extendedSets.begin(), extendedSets.end());
128 binary.append(memoryModel.begin(), memoryModel.end());
129 binary.append(entryPoints.begin(), entryPoints.end());
130 binary.append(executionModes.begin(), executionModes.end());
131 binary.append(debug.begin(), debug.end());
132 binary.append(names.begin(), names.end());
133 binary.append(decorations.begin(), decorations.end());
134 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
135 binary.append(functions.begin(), functions.end());
140 os <<
"\n= Value <id> Map =\n\n";
141 for (
auto valueIDPair : valueIDMap) {
142 Value val = valueIDPair.first;
143 os <<
" " << val <<
" "
144 <<
"id = " << valueIDPair.second <<
' ';
146 os <<
"from op '" << op->
getName() <<
"'";
147 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
148 Block *block = arg.getOwner();
149 os <<
"from argument of block " << block <<
' ';
161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162 auto funcID = funcIDMap.lookup(fnName);
164 funcID = getNextID();
165 funcIDMap[fnName] = funcID;
170 void Serializer::processCapability() {
171 for (
auto cap : module.getVceTriple()->getCapabilities())
173 {
static_cast<uint32_t
>(cap)});
176 void Serializer::processDebugInfo() {
179 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
181 fileID = getNextID();
183 operands.push_back(fileID);
189 void Serializer::processExtension() {
191 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
198 void Serializer::processMemoryModel() {
199 StringAttr memoryModelName = module.getMemoryModelAttrName();
200 auto mm =
static_cast<uint32_t
>(
201 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
204 StringAttr addressingModelName = module.getAddressingModelAttrName();
205 auto am =
static_cast<uint32_t
>(
206 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
215 if (attrName ==
"fp_fast_math_mode")
216 return "FPFastMathMode";
218 return llvm::convertToCamelFromSnakeCase(attrName,
true);
221 LogicalResult Serializer::processDecorationAttr(
Location loc, uint32_t resultID,
222 Decoration decoration,
225 switch (decoration) {
226 case spirv::Decoration::LinkageAttributes: {
229 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
230 auto linkageName = linkageAttr.getLinkageName();
231 auto linkageType = linkageAttr.getLinkageType().getValue();
235 args.push_back(
static_cast<uint32_t
>(linkageType));
238 case spirv::Decoration::FPFastMathMode:
239 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
240 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
243 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
244 << stringifyDecoration(decoration);
245 case spirv::Decoration::Binding:
246 case spirv::Decoration::DescriptorSet:
247 case spirv::Decoration::Location:
248 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
249 args.push_back(intAttr.getValue().getZExtValue());
252 return emitError(loc,
"expected integer attribute for ")
253 << stringifyDecoration(decoration);
254 case spirv::Decoration::BuiltIn:
255 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
256 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
258 args.push_back(
static_cast<uint32_t
>(*enumVal));
262 << stringifyDecoration(decoration) <<
" decoration attribute "
263 << strAttr.getValue();
265 return emitError(loc,
"expected string attribute for ")
266 << stringifyDecoration(decoration);
267 case spirv::Decoration::Aliased:
268 case spirv::Decoration::AliasedPointer:
269 case spirv::Decoration::Flat:
270 case spirv::Decoration::NonReadable:
271 case spirv::Decoration::NonWritable:
272 case spirv::Decoration::NoPerspective:
273 case spirv::Decoration::NoSignedWrap:
274 case spirv::Decoration::NoUnsignedWrap:
275 case spirv::Decoration::RelaxedPrecision:
276 case spirv::Decoration::Restrict:
277 case spirv::Decoration::RestrictPointer:
278 case spirv::Decoration::NoContraction:
281 if (isa<UnitAttr, DecorationAttr>(attr))
284 "expected unit attribute or decoration attribute for ")
285 << stringifyDecoration(decoration);
287 return emitError(loc,
"unhandled decoration ")
288 << stringifyDecoration(decoration);
290 return emitDecoration(resultID, decoration, args);
293 LogicalResult Serializer::processDecoration(
Location loc, uint32_t resultID,
295 StringRef attrName = attr.
getName().strref();
297 std::optional<Decoration> decoration =
298 spirv::symbolizeDecoration(decorationName);
301 loc,
"non-argument attributes expected to have snake-case-ified "
302 "decoration name, unhandled attribute with name : ")
305 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
308 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
309 assert(!name.empty() &&
"unexpected empty string for OpName");
314 nameOperands.push_back(resultID);
321 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
325 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
331 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
335 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
340 LogicalResult Serializer::processMemberDecoration(
345 static_cast<uint32_t
>(memberDecoration.
decoration)});
360 bool Serializer::isInterfaceStructPtrType(
Type type)
const {
361 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
362 switch (ptrType.getStorageClass()) {
363 case spirv::StorageClass::PhysicalStorageBuffer:
364 case spirv::StorageClass::PushConstant:
365 case spirv::StorageClass::StorageBuffer:
366 case spirv::StorageClass::Uniform:
367 return isa<spirv::StructType>(ptrType.getPointeeType());
375 LogicalResult Serializer::processType(
Location loc,
Type type,
380 return processTypeImpl(loc, type, typeID, serializationCtx);
384 Serializer::processTypeImpl(
Location loc,
Type type, uint32_t &typeID,
386 typeID = getTypeID(type);
390 typeID = getNextID();
393 operands.push_back(typeID);
394 auto typeEnum = spirv::Opcode::OpTypeVoid;
395 bool deferSerialization =
false;
397 if ((isa<FunctionType>(type) &&
398 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
400 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
401 deferSerialization, serializationCtx))) {
402 if (deferSerialization)
405 typeIDMap[type] = typeID;
409 if (recursiveStructInfos.count(type) != 0) {
412 for (
auto &ptrInfo : recursiveStructInfos[type]) {
416 ptrOperands.push_back(ptrInfo.pointerTypeID);
417 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
418 ptrOperands.push_back(typeIDMap[type]);
424 recursiveStructInfos[type].clear();
433 LogicalResult Serializer::prepareBasicType(
434 Location loc,
Type type, uint32_t resultID, spirv::Opcode &typeEnum,
437 deferSerialization =
false;
439 if (isVoidType(type)) {
440 typeEnum = spirv::Opcode::OpTypeVoid;
444 if (
auto intType = dyn_cast<IntegerType>(type)) {
445 if (intType.getWidth() == 1) {
446 typeEnum = spirv::Opcode::OpTypeBool;
450 typeEnum = spirv::Opcode::OpTypeInt;
451 operands.push_back(intType.getWidth());
456 operands.push_back(intType.isSigned() ? 1 : 0);
460 if (
auto floatType = dyn_cast<FloatType>(type)) {
461 typeEnum = spirv::Opcode::OpTypeFloat;
462 operands.push_back(floatType.getWidth());
466 if (
auto vectorType = dyn_cast<VectorType>(type)) {
467 uint32_t elementTypeID = 0;
468 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
469 serializationCtx))) {
472 typeEnum = spirv::Opcode::OpTypeVector;
473 operands.push_back(elementTypeID);
474 operands.push_back(vectorType.getNumElements());
478 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
479 typeEnum = spirv::Opcode::OpTypeImage;
480 uint32_t sampledTypeID = 0;
481 if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
484 llvm::append_values(operands, sampledTypeID,
485 static_cast<uint32_t
>(imageType.getDim()),
486 static_cast<uint32_t
>(imageType.getDepthInfo()),
487 static_cast<uint32_t
>(imageType.getArrayedInfo()),
488 static_cast<uint32_t
>(imageType.getSamplingInfo()),
489 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
490 static_cast<uint32_t
>(imageType.getImageFormat()));
494 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
495 typeEnum = spirv::Opcode::OpTypeArray;
496 uint32_t elementTypeID = 0;
497 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
498 serializationCtx))) {
501 operands.push_back(elementTypeID);
502 if (
auto elementCountID = prepareConstantInt(
504 operands.push_back(elementCountID);
506 return processTypeDecoration(loc, arrayType, resultID);
509 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
510 uint32_t pointeeTypeID = 0;
512 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
515 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
521 forwardPtrOperands.push_back(resultID);
522 forwardPtrOperands.push_back(
523 static_cast<uint32_t
>(ptrType.getStorageClass()));
526 spirv::Opcode::OpTypeForwardPointer,
538 deferSerialization =
true;
542 recursiveStructInfos[structType].push_back(
543 {resultID, ptrType.getStorageClass()});
545 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
550 typeEnum = spirv::Opcode::OpTypePointer;
551 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
552 operands.push_back(pointeeTypeID);
554 if (isInterfaceStructPtrType(ptrType)) {
555 if (failed(emitDecoration(getTypeID(pointeeStruct),
556 spirv::Decoration::Block)))
557 return emitError(loc,
"cannot decorate ")
558 << pointeeStruct <<
" with Block decoration";
564 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
565 uint32_t elementTypeID = 0;
566 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
567 elementTypeID, serializationCtx))) {
570 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
571 operands.push_back(elementTypeID);
572 return processTypeDecoration(loc, runtimeArrayType, resultID);
575 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
576 typeEnum = spirv::Opcode::OpTypeSampledImage;
577 uint32_t imageTypeID = 0;
579 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
582 operands.push_back(imageTypeID);
586 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
587 if (structType.isIdentified()) {
588 if (failed(processName(resultID, structType.getIdentifier())))
590 serializationCtx.insert(structType.getIdentifier());
593 bool hasOffset = structType.hasOffset();
594 for (
auto elementIndex :
595 llvm::seq<uint32_t>(0, structType.getNumElements())) {
596 uint32_t elementTypeID = 0;
597 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
598 elementTypeID, serializationCtx))) {
601 operands.push_back(elementTypeID);
605 elementIndex, 1, spirv::Decoration::Offset,
606 static_cast<uint32_t
>(structType.getMemberOffset(elementIndex))};
607 if (failed(processMemberDecoration(resultID, offsetDecoration))) {
608 return emitError(loc,
"cannot decorate ")
609 << elementIndex <<
"-th member of " << structType
610 <<
" with its offset";
615 structType.getMemberDecorations(memberDecorations);
617 for (
auto &memberDecoration : memberDecorations) {
618 if (failed(processMemberDecoration(resultID, memberDecoration))) {
619 return emitError(loc,
"cannot decorate ")
620 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
621 <<
"-th member of " << structType <<
" with "
622 << stringifyDecoration(memberDecoration.
decoration);
626 typeEnum = spirv::Opcode::OpTypeStruct;
628 if (structType.isIdentified())
629 serializationCtx.remove(structType.getIdentifier());
634 if (
auto cooperativeMatrixType =
635 dyn_cast<spirv::CooperativeMatrixType>(type)) {
636 uint32_t elementTypeID = 0;
637 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
638 elementTypeID, serializationCtx))) {
641 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
642 auto getConstantOp = [&](uint32_t id) {
644 return prepareConstantInt(loc, attr);
647 operands, elementTypeID,
648 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
649 getConstantOp(cooperativeMatrixType.getRows()),
650 getConstantOp(cooperativeMatrixType.getColumns()),
651 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
655 if (
auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
656 uint32_t elementTypeID = 0;
657 if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
658 elementTypeID, serializationCtx))) {
661 typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
662 auto getConstantOp = [&](uint32_t id) {
664 return prepareConstantInt(loc, attr);
667 operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
668 getConstantOp(jointMatrixType.getColumns()),
669 getConstantOp(
static_cast<uint32_t
>(jointMatrixType.getMatrixLayout())),
670 getConstantOp(
static_cast<uint32_t
>(jointMatrixType.getScope())));
674 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
675 uint32_t elementTypeID = 0;
676 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
677 serializationCtx))) {
680 typeEnum = spirv::Opcode::OpTypeMatrix;
681 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
686 return emitError(loc,
"unhandled type in serialization: ") << type;
690 Serializer::prepareFunctionType(
Location loc, FunctionType type,
691 spirv::Opcode &typeEnum,
693 typeEnum = spirv::Opcode::OpTypeFunction;
694 assert(type.getNumResults() <= 1 &&
695 "serialization supports only a single return value");
696 uint32_t resultID = 0;
697 if (failed(processType(
698 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
702 operands.push_back(resultID);
703 for (
auto &res : type.getInputs()) {
704 uint32_t argTypeID = 0;
705 if (failed(processType(loc, res, argTypeID))) {
708 operands.push_back(argTypeID);
717 uint32_t Serializer::prepareConstant(
Location loc,
Type constType,
719 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
726 if (
auto id = getConstantID(valueAttr)) {
731 if (failed(processType(loc, constType, typeID))) {
735 uint32_t resultID = 0;
736 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
737 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
739 resultID = prepareDenseElementsConstant(loc, constType, attr,
741 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
742 resultID = prepareArrayConstant(loc, constType, arrayAttr);
746 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
750 constIDMap[valueAttr] = resultID;
754 uint32_t Serializer::prepareArrayConstant(
Location loc,
Type constType,
757 if (failed(processType(loc, constType, typeID))) {
761 uint32_t resultID = getNextID();
763 operands.reserve(attr.size() + 2);
764 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
766 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
767 operands.push_back(elementID);
772 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
781 Serializer::prepareDenseElementsConstant(
Location loc,
Type constType,
784 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
785 assert(dim <= shapedType.getRank());
786 if (shapedType.getRank() == dim) {
787 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
788 return attr.getType().getElementType().isInteger(1)
789 ? prepareConstantBool(loc, attr.getValues<
BoolAttr>()[index])
790 : prepareConstantInt(loc,
791 attr.getValues<IntegerAttr>()[index]);
793 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
794 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
800 if (failed(processType(loc, constType, typeID))) {
804 uint32_t resultID = getNextID();
806 operands.reserve(shapedType.getDimSize(dim) + 2);
807 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
808 for (
int i = 0; i < shapedType.getDimSize(dim); ++i) {
810 if (
auto elementID = prepareDenseElementsConstant(
811 loc, elementType, valueAttr, dim + 1, index)) {
812 operands.push_back(elementID);
817 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
825 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
826 return prepareConstantFp(loc, floatAttr, isSpec);
828 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
829 return prepareConstantBool(loc, boolAttr, isSpec);
831 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
832 return prepareConstantInt(loc, intAttr, isSpec);
842 if (
auto id = getConstantID(boolAttr)) {
849 if (failed(processType(loc, cast<IntegerAttr>(boolAttr).
getType(), typeID))) {
853 auto resultID = getNextID();
855 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
856 : spirv::Opcode::OpConstantTrue)
857 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
858 : spirv::Opcode::OpConstantFalse);
862 constIDMap[boolAttr] = resultID;
867 uint32_t Serializer::prepareConstantInt(
Location loc, IntegerAttr intAttr,
871 if (
auto id = getConstantID(intAttr)) {
878 if (failed(processType(loc, intAttr.getType(), typeID))) {
882 auto resultID = getNextID();
883 APInt value = intAttr.getValue();
884 unsigned bitwidth = value.getBitWidth();
885 bool isSigned = intAttr.getType().isSignedInteger();
887 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
900 word =
static_cast<int32_t
>(value.getSExtValue());
902 word =
static_cast<uint32_t
>(value.getZExtValue());
914 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
916 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
919 {typeID, resultID, words.word1, words.word2});
922 std::string valueStr;
923 llvm::raw_string_ostream rss(valueStr);
924 value.print(rss,
false);
927 << bitwidth <<
"-bit integer literal: " << rss.str();
933 constIDMap[intAttr] = resultID;
938 uint32_t Serializer::prepareConstantFp(
Location loc, FloatAttr floatAttr,
942 if (
auto id = getConstantID(floatAttr)) {
949 if (failed(processType(loc, floatAttr.getType(), typeID))) {
953 auto resultID = getNextID();
954 APFloat value = floatAttr.getValue();
955 APInt intValue = value.bitcastToAPInt();
958 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
960 if (&value.getSemantics() == &APFloat::IEEEsingle()) {
961 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
963 }
else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
967 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
969 {typeID, resultID, words.word1, words.word2});
970 }
else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
972 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
975 std::string valueStr;
976 llvm::raw_string_ostream rss(valueStr);
980 << floatAttr.getType() <<
"-typed float literal: " << rss.str();
985 constIDMap[floatAttr] = resultID;
994 uint32_t Serializer::getOrCreateBlockID(
Block *block) {
995 if (uint32_t
id = getBlockID(block))
997 return blockIDMap[block] = getNextID();
1001 void Serializer::printBlock(
Block *block, raw_ostream &os) {
1002 os <<
"block " << block <<
" (id = ";
1003 if (uint32_t
id = getBlockID(block))
1012 Serializer::processBlock(
Block *block,
bool omitLabel,
1014 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1015 LLVM_DEBUG(block->
print(llvm::dbgs()));
1016 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1018 uint32_t blockID = getOrCreateBlockID(block);
1019 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1026 if (failed(emitPhiForBlockArguments(block)))
1036 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1037 if (failed(emitMerge()))
1039 emitMerge =
nullptr;
1042 uint32_t blockID = getNextID();
1048 for (
Operation &op : llvm::drop_end(*block)) {
1049 if (failed(processOperation(&op)))
1055 if (failed(emitMerge()))
1057 if (failed(processOperation(&block->
back())))
1063 LogicalResult Serializer::emitPhiForBlockArguments(
Block *block) {
1069 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1078 auto *terminator = mlirPredecessor->getTerminator();
1079 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1080 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1081 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1090 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1091 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1092 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1093 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1094 }
else if (
auto branchCondOp =
1095 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1096 std::optional<OperandRange> blockOperands;
1097 if (branchCondOp.getTrueTarget() == block) {
1098 blockOperands = branchCondOp.getTrueTargetOperands();
1100 assert(branchCondOp.getFalseTarget() == block);
1101 blockOperands = branchCondOp.getFalseTargetOperands();
1104 assert(!blockOperands->empty() &&
1105 "expected non-empty block operand range");
1106 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1108 return terminator->emitError(
"unimplemented terminator for Phi creation");
1111 llvm::dbgs() <<
" block arguments:\n";
1112 for (
Value v : predecessors.back().second)
1113 llvm::dbgs() <<
" " << v <<
"\n";
1118 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1122 uint32_t phiTypeID = 0;
1123 if (failed(processType(arg.
getLoc(), arg.
getType(), phiTypeID)))
1125 uint32_t phiID = getNextID();
1127 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1128 << arg <<
" (id = " << phiID <<
")\n");
1132 phiArgs.push_back(phiTypeID);
1133 phiArgs.push_back(phiID);
1135 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1136 Value value = predecessors[predIndex].second[argIndex];
1137 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1138 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1139 <<
") value " << value <<
' ');
1141 uint32_t valueId = getValueID(value);
1145 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1146 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1149 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1151 phiArgs.push_back(valueId);
1153 phiArgs.push_back(predBlockId);
1157 valueIDMap[arg] = phiID;
1167 LogicalResult Serializer::encodeExtensionInstruction(
1168 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1171 auto &setID = extendedInstSetIDMap[extensionSetName];
1173 setID = getNextID();
1175 importOperands.push_back(setID);
1183 if (operands.size() < 2) {
1184 return op->
emitError(
"extended instructions must have a result encoding");
1187 extInstOperands.reserve(operands.size() + 2);
1188 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1189 extInstOperands.push_back(setID);
1190 extInstOperands.push_back(extensionOpcode);
1191 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1197 LogicalResult Serializer::processOperation(
Operation *opInst) {
1198 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1203 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1204 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1205 .Case([&](spirv::BranchConditionalOp op) {
1206 return processBranchConditionalOp(op);
1208 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1209 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1210 .Case([&](spirv::GlobalVariableOp op) {
1211 return processGlobalVariableOp(op);
1213 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1214 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1215 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1216 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1217 .Case([&](spirv::SpecConstantCompositeOp op) {
1218 return processSpecConstantCompositeOp(op);
1220 .Case([&](spirv::SpecConstantOperationOp op) {
1221 return processSpecConstantOperationOp(op);
1223 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1224 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1229 [&](
Operation *op) {
return dispatchToAutogenSerialization(op); });
1232 LogicalResult Serializer::processOpWithoutGrammarAttr(
Operation *op,
1233 StringRef extInstSet,
1238 uint32_t resultID = 0;
1240 uint32_t resultTypeID = 0;
1243 operands.push_back(resultTypeID);
1245 resultID = getNextID();
1246 operands.push_back(resultID);
1247 valueIDMap[op->
getResult(0)] = resultID;
1251 operands.push_back(getValueID(operand));
1253 if (failed(emitDebugLine(functionBody, loc)))
1256 if (extInstSet.empty()) {
1260 if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1266 if (failed(processDecoration(loc, resultID, attr)))
1274 LogicalResult Serializer::emitDecoration(uint32_t target,
1275 spirv::Decoration decoration,
1277 uint32_t wordCount = 3 + params.size();
1278 llvm::append_values(
1281 static_cast<uint32_t
>(decoration));
1282 llvm::append_range(decorations, params);
1291 if (lastProcessedWasMergeInst) {
1292 lastProcessedWasMergeInst =
false;
1296 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1299 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
An attribute that represents a reference to a dense vector or tensor object.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPIR-V module header to `header` with the given version and `idBound`.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Converts a snake_case attribute name into the CamelCase SPIR-V decoration name ("fp_fast_math_mode" is special-cased to "FPFastMathMode").
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.