21 #include "llvm/ADT/Sequence.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/ADT/bit.h"
26 #include "llvm/Support/Debug.h"
30 #define DEBUG_TYPE "spirv-serialization"
37 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
38 return selectionOp.getMergeBlock();
39 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
40 return loopOp.getMergeBlock();
51 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
55 while ((op = op->getPrevNode()) !=
nullptr)
80 uint32_t wordCount = 1 + operands.size();
82 binary.append(operands.begin(), operands.end());
90 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
92 if (
failed(module.verifyInvariants()))
103 for (
auto &op : *module.getBody()) {
104 if (
failed(processOperation(&op))) {
109 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
115 extensions.size() + extendedSets.size() +
116 memoryModel.size() + entryPoints.size() +
117 executionModes.size() + decorations.size() +
118 typesGlobalValues.size() + functions.size();
121 binary.reserve(moduleSize);
125 binary.append(capabilities.begin(), capabilities.end());
126 binary.append(extensions.begin(), extensions.end());
127 binary.append(extendedSets.begin(), extendedSets.end());
128 binary.append(memoryModel.begin(), memoryModel.end());
129 binary.append(entryPoints.begin(), entryPoints.end());
130 binary.append(executionModes.begin(), executionModes.end());
131 binary.append(debug.begin(), debug.end());
132 binary.append(names.begin(), names.end());
133 binary.append(decorations.begin(), decorations.end());
134 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
135 binary.append(functions.begin(), functions.end());
140 os <<
"\n= Value <id> Map =\n\n";
141 for (
auto valueIDPair : valueIDMap) {
142 Value val = valueIDPair.first;
143 os <<
" " << val <<
" "
144 <<
"id = " << valueIDPair.second <<
' ';
146 os <<
"from op '" << op->
getName() <<
"'";
147 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
148 Block *block = arg.getOwner();
149 os <<
"from argument of block " << block <<
' ';
161 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
162 auto funcID = funcIDMap.lookup(fnName);
164 funcID = getNextID();
165 funcIDMap[fnName] = funcID;
170 void Serializer::processCapability() {
171 for (
auto cap : module.getVceTriple()->getCapabilities())
173 {
static_cast<uint32_t
>(cap)});
176 void Serializer::processDebugInfo() {
179 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
180 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
181 fileID = getNextID();
183 operands.push_back(fileID);
189 void Serializer::processExtension() {
191 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
198 void Serializer::processMemoryModel() {
199 auto mm =
static_cast<uint32_t
>(
200 module->getAttrOfType<spirv::MemoryModelAttr>(
"memory_model").getValue());
201 auto am =
static_cast<uint32_t
>(
202 module->getAttrOfType<spirv::AddressingModelAttr>(
"addressing_model")
210 auto attrName = attr.
getName().strref();
211 auto decorationName =
212 llvm::convertToCamelFromSnakeCase(attrName,
true);
213 auto decoration = spirv::symbolizeDecoration(decorationName);
216 loc,
"non-argument attributes expected to have snake-case-ified "
217 "decoration name, unhandled attribute with name : ")
221 switch (*decoration) {
222 case spirv::Decoration::LinkageAttributes: {
225 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr.
getValue());
226 auto linkageName = linkageAttr.getLinkageName();
227 auto linkageType = linkageAttr.getLinkageType().getValue();
231 args.push_back(
static_cast<uint32_t
>(linkageType));
234 case spirv::Decoration::Binding:
235 case spirv::Decoration::DescriptorSet:
236 case spirv::Decoration::Location:
237 if (
auto intAttr = dyn_cast<IntegerAttr>(attr.
getValue())) {
238 args.push_back(intAttr.getValue().getZExtValue());
241 return emitError(loc,
"expected integer attribute for ") << attrName;
242 case spirv::Decoration::BuiltIn:
243 if (
auto strAttr = dyn_cast<StringAttr>(attr.
getValue())) {
244 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
246 args.push_back(
static_cast<uint32_t
>(*enumVal));
250 << attrName <<
" attribute " << strAttr.getValue();
252 return emitError(loc,
"expected string attribute for ") << attrName;
253 case spirv::Decoration::Aliased:
254 case spirv::Decoration::Flat:
255 case spirv::Decoration::NonReadable:
256 case spirv::Decoration::NonWritable:
257 case spirv::Decoration::NoPerspective:
258 case spirv::Decoration::Restrict:
259 case spirv::Decoration::RelaxedPrecision:
261 if (
auto unitAttr = dyn_cast<UnitAttr>(attr.
getValue()))
263 return emitError(loc,
"expected unit attribute for ") << attrName;
265 return emitError(loc,
"unhandled decoration ") << decorationName;
267 return emitDecoration(resultID, *decoration, args);
270 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
271 assert(!name.empty() &&
"unexpected empty string for OpName");
276 nameOperands.push_back(resultID);
287 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
297 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
307 static_cast<uint32_t
>(memberDecoration.
decoration)});
322 bool Serializer::isInterfaceStructPtrType(
Type type)
const {
323 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
324 switch (ptrType.getStorageClass()) {
325 case spirv::StorageClass::PhysicalStorageBuffer:
326 case spirv::StorageClass::PushConstant:
327 case spirv::StorageClass::StorageBuffer:
328 case spirv::StorageClass::Uniform:
329 return isa<spirv::StructType>(ptrType.getPointeeType());
342 return processTypeImpl(loc, type, typeID, serializationCtx);
346 Serializer::processTypeImpl(
Location loc,
Type type, uint32_t &typeID,
348 typeID = getTypeID(type);
352 typeID = getNextID();
355 operands.push_back(typeID);
356 auto typeEnum = spirv::Opcode::OpTypeVoid;
357 bool deferSerialization =
false;
359 if ((isa<FunctionType>(type) &&
360 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
362 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
363 deferSerialization, serializationCtx))) {
364 if (deferSerialization)
367 typeIDMap[type] = typeID;
371 if (recursiveStructInfos.count(type) != 0) {
374 for (
auto &ptrInfo : recursiveStructInfos[type]) {
378 ptrOperands.push_back(ptrInfo.pointerTypeID);
379 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
380 ptrOperands.push_back(typeIDMap[type]);
386 recursiveStructInfos[type].clear();
396 Location loc,
Type type, uint32_t resultID, spirv::Opcode &typeEnum,
399 deferSerialization =
false;
401 if (isVoidType(type)) {
402 typeEnum = spirv::Opcode::OpTypeVoid;
406 if (
auto intType = dyn_cast<IntegerType>(type)) {
407 if (intType.getWidth() == 1) {
408 typeEnum = spirv::Opcode::OpTypeBool;
412 typeEnum = spirv::Opcode::OpTypeInt;
413 operands.push_back(intType.getWidth());
418 operands.push_back(intType.isSigned() ? 1 : 0);
422 if (
auto floatType = dyn_cast<FloatType>(type)) {
423 typeEnum = spirv::Opcode::OpTypeFloat;
424 operands.push_back(floatType.getWidth());
428 if (
auto vectorType = dyn_cast<VectorType>(type)) {
429 uint32_t elementTypeID = 0;
430 if (
failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
431 serializationCtx))) {
434 typeEnum = spirv::Opcode::OpTypeVector;
435 operands.push_back(elementTypeID);
436 operands.push_back(vectorType.getNumElements());
440 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
441 typeEnum = spirv::Opcode::OpTypeImage;
442 uint32_t sampledTypeID = 0;
443 if (
failed(processType(loc, imageType.getElementType(), sampledTypeID)))
446 operands.push_back(sampledTypeID);
447 operands.push_back(
static_cast<uint32_t
>(imageType.getDim()));
448 operands.push_back(
static_cast<uint32_t
>(imageType.getDepthInfo()));
449 operands.push_back(
static_cast<uint32_t
>(imageType.getArrayedInfo()));
450 operands.push_back(
static_cast<uint32_t
>(imageType.getSamplingInfo()));
451 operands.push_back(
static_cast<uint32_t
>(imageType.getSamplerUseInfo()));
452 operands.push_back(
static_cast<uint32_t
>(imageType.getImageFormat()));
456 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
457 typeEnum = spirv::Opcode::OpTypeArray;
458 uint32_t elementTypeID = 0;
459 if (
failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
460 serializationCtx))) {
463 operands.push_back(elementTypeID);
464 if (
auto elementCountID = prepareConstantInt(
466 operands.push_back(elementCountID);
468 return processTypeDecoration(loc, arrayType, resultID);
471 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
472 uint32_t pointeeTypeID = 0;
474 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
477 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
483 forwardPtrOperands.push_back(resultID);
484 forwardPtrOperands.push_back(
485 static_cast<uint32_t
>(ptrType.getStorageClass()));
488 spirv::Opcode::OpTypeForwardPointer,
500 deferSerialization =
true;
504 recursiveStructInfos[structType].push_back(
505 {resultID, ptrType.getStorageClass()});
507 if (
failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
512 typeEnum = spirv::Opcode::OpTypePointer;
513 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
514 operands.push_back(pointeeTypeID);
516 if (isInterfaceStructPtrType(ptrType)) {
517 if (
failed(emitDecoration(getTypeID(pointeeStruct),
518 spirv::Decoration::Block)))
519 return emitError(loc,
"cannot decorate ")
520 << pointeeStruct <<
" with Block decoration";
526 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
527 uint32_t elementTypeID = 0;
528 if (
failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
529 elementTypeID, serializationCtx))) {
532 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
533 operands.push_back(elementTypeID);
534 return processTypeDecoration(loc, runtimeArrayType, resultID);
537 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
538 typeEnum = spirv::Opcode::OpTypeSampledImage;
539 uint32_t imageTypeID = 0;
541 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
544 operands.push_back(imageTypeID);
548 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
549 if (structType.isIdentified()) {
550 if (
failed(processName(resultID, structType.getIdentifier())))
552 serializationCtx.insert(structType.getIdentifier());
555 bool hasOffset = structType.hasOffset();
556 for (
auto elementIndex :
557 llvm::seq<uint32_t>(0, structType.getNumElements())) {
558 uint32_t elementTypeID = 0;
559 if (
failed(processTypeImpl(loc, structType.getElementType(elementIndex),
560 elementTypeID, serializationCtx))) {
563 operands.push_back(elementTypeID);
567 elementIndex, 1, spirv::Decoration::Offset,
568 static_cast<uint32_t
>(structType.getMemberOffset(elementIndex))};
569 if (
failed(processMemberDecoration(resultID, offsetDecoration))) {
570 return emitError(loc,
"cannot decorate ")
571 << elementIndex <<
"-th member of " << structType
572 <<
" with its offset";
577 structType.getMemberDecorations(memberDecorations);
579 for (
auto &memberDecoration : memberDecorations) {
580 if (
failed(processMemberDecoration(resultID, memberDecoration))) {
581 return emitError(loc,
"cannot decorate ")
582 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
583 <<
"-th member of " << structType <<
" with "
584 << stringifyDecoration(memberDecoration.
decoration);
588 typeEnum = spirv::Opcode::OpTypeStruct;
590 if (structType.isIdentified())
591 serializationCtx.remove(structType.getIdentifier());
596 if (
auto cooperativeMatrixType =
597 dyn_cast<spirv::CooperativeMatrixType>(type)) {
598 uint32_t elementTypeID = 0;
599 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
600 elementTypeID, serializationCtx))) {
603 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
604 auto getConstantOp = [&](uint32_t id) {
606 return prepareConstantInt(loc, attr);
608 operands.push_back(elementTypeID);
610 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())));
611 operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
612 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
614 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
618 if (
auto cooperativeMatrixType =
619 dyn_cast<spirv::CooperativeMatrixNVType>(type)) {
620 uint32_t elementTypeID = 0;
621 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
622 elementTypeID, serializationCtx))) {
625 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
626 auto getConstantOp = [&](uint32_t id) {
628 return prepareConstantInt(loc, attr);
630 operands.push_back(elementTypeID);
632 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())));
633 operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
634 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
638 if (
auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
639 uint32_t elementTypeID = 0;
640 if (
failed(processTypeImpl(loc, jointMatrixType.getElementType(),
641 elementTypeID, serializationCtx))) {
644 typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
645 auto getConstantOp = [&](uint32_t id) {
647 return prepareConstantInt(loc, attr);
649 operands.push_back(elementTypeID);
650 operands.push_back(getConstantOp(jointMatrixType.getRows()));
651 operands.push_back(getConstantOp(jointMatrixType.getColumns()));
652 operands.push_back(getConstantOp(
653 static_cast<uint32_t
>(jointMatrixType.getMatrixLayout())));
655 getConstantOp(
static_cast<uint32_t
>(jointMatrixType.getScope())));
659 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
660 uint32_t elementTypeID = 0;
661 if (
failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
662 serializationCtx))) {
665 typeEnum = spirv::Opcode::OpTypeMatrix;
666 operands.push_back(elementTypeID);
667 operands.push_back(matrixType.getNumColumns());
672 return emitError(loc,
"unhandled type in serialization: ") << type;
676 Serializer::prepareFunctionType(
Location loc, FunctionType type,
677 spirv::Opcode &typeEnum,
679 typeEnum = spirv::Opcode::OpTypeFunction;
680 assert(type.getNumResults() <= 1 &&
681 "serialization supports only a single return value");
682 uint32_t resultID = 0;
684 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
688 operands.push_back(resultID);
689 for (
auto &res : type.getInputs()) {
690 uint32_t argTypeID = 0;
691 if (
failed(processType(loc, res, argTypeID))) {
694 operands.push_back(argTypeID);
703 uint32_t Serializer::prepareConstant(
Location loc,
Type constType,
705 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
712 if (
auto id = getConstantID(valueAttr)) {
717 if (
failed(processType(loc, constType, typeID))) {
721 uint32_t resultID = 0;
722 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
723 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
725 resultID = prepareDenseElementsConstant(loc, constType, attr,
727 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
728 resultID = prepareArrayConstant(loc, constType, arrayAttr);
732 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
736 constIDMap[valueAttr] = resultID;
740 uint32_t Serializer::prepareArrayConstant(
Location loc,
Type constType,
743 if (
failed(processType(loc, constType, typeID))) {
747 uint32_t resultID = getNextID();
749 operands.reserve(attr.size() + 2);
750 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
752 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
753 operands.push_back(elementID);
758 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
767 Serializer::prepareDenseElementsConstant(
Location loc,
Type constType,
770 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
771 assert(dim <= shapedType.getRank());
772 if (shapedType.getRank() == dim) {
773 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
774 return attr.getType().getElementType().isInteger(1)
775 ? prepareConstantBool(loc, attr.getValues<
BoolAttr>()[index])
776 : prepareConstantInt(loc,
777 attr.getValues<IntegerAttr>()[index]);
779 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
780 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
786 if (
failed(processType(loc, constType, typeID))) {
790 uint32_t resultID = getNextID();
792 operands.reserve(shapedType.getDimSize(dim) + 2);
793 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
794 for (
int i = 0; i < shapedType.getDimSize(dim); ++i) {
796 if (
auto elementID = prepareDenseElementsConstant(
797 loc, elementType, valueAttr, dim + 1, index)) {
798 operands.push_back(elementID);
803 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
811 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
812 return prepareConstantFp(loc, floatAttr, isSpec);
814 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
815 return prepareConstantBool(loc, boolAttr, isSpec);
817 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
818 return prepareConstantInt(loc, intAttr, isSpec);
828 if (
auto id = getConstantID(boolAttr)) {
835 if (
failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
839 auto resultID = getNextID();
841 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
842 : spirv::Opcode::OpConstantTrue)
843 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
844 : spirv::Opcode::OpConstantFalse);
848 constIDMap[boolAttr] = resultID;
853 uint32_t Serializer::prepareConstantInt(
Location loc, IntegerAttr intAttr,
857 if (
auto id = getConstantID(intAttr)) {
864 if (
failed(processType(loc, intAttr.getType(), typeID))) {
868 auto resultID = getNextID();
869 APInt value = intAttr.getValue();
870 unsigned bitwidth = value.getBitWidth();
871 bool isSigned = intAttr.getType().isSignedInteger();
873 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
886 word =
static_cast<int32_t
>(value.getSExtValue());
888 word =
static_cast<uint32_t
>(value.getZExtValue());
900 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
902 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
905 {typeID, resultID, words.word1, words.word2});
908 std::string valueStr;
909 llvm::raw_string_ostream rss(valueStr);
910 value.print(rss,
false);
913 << bitwidth <<
"-bit integer literal: " << rss.str();
919 constIDMap[intAttr] = resultID;
924 uint32_t Serializer::prepareConstantFp(
Location loc, FloatAttr floatAttr,
928 if (
auto id = getConstantID(floatAttr)) {
935 if (
failed(processType(loc, floatAttr.getType(), typeID))) {
939 auto resultID = getNextID();
940 APFloat value = floatAttr.getValue();
941 APInt intValue = value.bitcastToAPInt();
944 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
946 if (&value.getSemantics() == &APFloat::IEEEsingle()) {
947 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
949 }
else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
953 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
955 {typeID, resultID, words.word1, words.word2});
956 }
else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
958 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
961 std::string valueStr;
962 llvm::raw_string_ostream rss(valueStr);
966 << floatAttr.getType() <<
"-typed float literal: " << rss.str();
971 constIDMap[floatAttr] = resultID;
980 uint32_t Serializer::getOrCreateBlockID(
Block *block) {
981 if (uint32_t
id = getBlockID(block))
983 return blockIDMap[block] = getNextID();
987 void Serializer::printBlock(
Block *block, raw_ostream &os) {
988 os <<
"block " << block <<
" (id = ";
989 if (uint32_t
id = getBlockID(block))
998 Serializer::processBlock(
Block *block,
bool omitLabel,
1000 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1001 LLVM_DEBUG(block->
print(llvm::dbgs()));
1002 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1004 uint32_t blockID = getOrCreateBlockID(block);
1005 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1012 if (
failed(emitPhiForBlockArguments(block)))
1021 return isa<spirv::LoopOp, spirv::SelectionOp>(op);
1025 emitMerge =
nullptr;
1028 uint32_t blockID = getNextID();
1034 for (
auto &op : llvm::make_range(block->
begin(), std::prev(block->
end()))) {
1035 if (
failed(processOperation(&op)))
1043 if (
failed(processOperation(&block->
back())))
1055 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1064 auto *terminator = mlirPredecessor->getTerminator();
1065 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1066 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1067 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1076 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1077 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1078 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1079 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1080 }
else if (
auto branchCondOp =
1081 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1082 std::optional<OperandRange> blockOperands;
1083 if (branchCondOp.getTrueTarget() == block) {
1084 blockOperands = branchCondOp.getTrueTargetOperands();
1086 assert(branchCondOp.getFalseTarget() == block);
1087 blockOperands = branchCondOp.getFalseTargetOperands();
1090 assert(!blockOperands->empty() &&
1091 "expected non-empty block operand range");
1092 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1094 return terminator->emitError(
"unimplemented terminator for Phi creation");
1097 llvm::dbgs() <<
" block arguments:\n";
1098 for (
Value v : predecessors.back().second)
1099 llvm::dbgs() <<
" " << v <<
"\n";
1104 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1108 uint32_t phiTypeID = 0;
1111 uint32_t phiID = getNextID();
1113 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1114 << arg <<
" (id = " << phiID <<
")\n");
1118 phiArgs.push_back(phiTypeID);
1119 phiArgs.push_back(phiID);
1121 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1122 Value value = predecessors[predIndex].second[argIndex];
1123 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1124 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1125 <<
") value " << value <<
' ');
1127 uint32_t valueId = getValueID(value);
1131 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1132 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1135 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1137 phiArgs.push_back(valueId);
1139 phiArgs.push_back(predBlockId);
1143 valueIDMap[arg] = phiID;
1154 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1157 auto &setID = extendedInstSetIDMap[extensionSetName];
1159 setID = getNextID();
1161 importOperands.push_back(setID);
1169 if (operands.size() < 2) {
1170 return op->
emitError(
"extended instructions must have a result encoding");
1173 extInstOperands.reserve(operands.size() + 2);
1174 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1175 extInstOperands.push_back(setID);
1176 extInstOperands.push_back(extensionOpcode);
1177 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1184 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1189 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1190 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1191 .Case([&](spirv::BranchConditionalOp op) {
1192 return processBranchConditionalOp(op);
1194 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1195 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1196 .Case([&](spirv::GlobalVariableOp op) {
1197 return processGlobalVariableOp(op);
1199 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1200 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1201 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1202 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1203 .Case([&](spirv::SpecConstantCompositeOp op) {
1204 return processSpecConstantCompositeOp(op);
1206 .Case([&](spirv::SpecConstantOperationOp op) {
1207 return processSpecConstantOperationOp(op);
1209 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1210 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1215 [&](
Operation *op) {
return dispatchToAutogenSerialization(op); });
1219 StringRef extInstSet,
1224 uint32_t resultID = 0;
1226 uint32_t resultTypeID = 0;
1229 operands.push_back(resultTypeID);
1231 resultID = getNextID();
1232 operands.push_back(resultID);
1233 valueIDMap[op->
getResult(0)] = resultID;
1237 operands.push_back(getValueID(operand));
1239 if (
failed(emitDebugLine(functionBody, loc)))
1242 if (extInstSet.empty()) {
1246 if (
failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1252 if (
failed(processDecoration(loc, resultID, attr)))
1261 spirv::Decoration decoration,
1263 uint32_t wordCount = 3 + params.size();
1264 decorations.push_back(
1266 decorations.push_back(target);
1267 decorations.push_back(
static_cast<uint32_t
>(decoration));
1268 decorations.append(params.begin(), params.end());
1277 if (lastProcessedWasMergeInst) {
1278 lastProcessedWasMergeInst =
false;
1282 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1285 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
An attribute that represents a reference to a dense vector or tensor object.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPRI-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.