21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/ADT/Sequence.h"
23 #include "llvm/ADT/SmallPtrSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/ADT/bit.h"
27 #include "llvm/Support/Debug.h"
31 #define DEBUG_TYPE "spirv-serialization"
38 if (
auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
39 return selectionOp.getMergeBlock();
40 if (
auto loopOp = dyn_cast<spirv::LoopOp>(op))
41 return loopOp.getMergeBlock();
52 if (
auto loopOp = dyn_cast<spirv::LoopOp>(block->
getParentOp())) {
56 while ((op = op->getPrevNode()) !=
nullptr)
81 uint32_t wordCount = 1 + operands.size();
83 binary.append(operands.begin(), operands.end());
91 LLVM_DEBUG(llvm::dbgs() <<
"+++ starting serialization +++\n");
93 if (
failed(module.verifyInvariants()))
104 for (
auto &op : *module.getBody()) {
105 if (
failed(processOperation(&op))) {
110 LLVM_DEBUG(llvm::dbgs() <<
"+++ completed serialization +++\n");
116 extensions.size() + extendedSets.size() +
117 memoryModel.size() + entryPoints.size() +
118 executionModes.size() + decorations.size() +
119 typesGlobalValues.size() + functions.size();
122 binary.reserve(moduleSize);
126 binary.append(capabilities.begin(), capabilities.end());
127 binary.append(extensions.begin(), extensions.end());
128 binary.append(extendedSets.begin(), extendedSets.end());
129 binary.append(memoryModel.begin(), memoryModel.end());
130 binary.append(entryPoints.begin(), entryPoints.end());
131 binary.append(executionModes.begin(), executionModes.end());
132 binary.append(debug.begin(), debug.end());
133 binary.append(names.begin(), names.end());
134 binary.append(decorations.begin(), decorations.end());
135 binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
136 binary.append(functions.begin(), functions.end());
141 os <<
"\n= Value <id> Map =\n\n";
142 for (
auto valueIDPair : valueIDMap) {
143 Value val = valueIDPair.first;
144 os <<
" " << val <<
" "
145 <<
"id = " << valueIDPair.second <<
' ';
147 os <<
"from op '" << op->
getName() <<
"'";
148 }
else if (
auto arg = dyn_cast<BlockArgument>(val)) {
149 Block *block = arg.getOwner();
150 os <<
"from argument of block " << block <<
' ';
162 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
163 auto funcID = funcIDMap.lookup(fnName);
165 funcID = getNextID();
166 funcIDMap[fnName] = funcID;
171 void Serializer::processCapability() {
172 for (
auto cap : module.getVceTriple()->getCapabilities())
174 {
static_cast<uint32_t
>(cap)});
177 void Serializer::processDebugInfo() {
180 auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
181 auto fileName = fileLoc ? fileLoc.getFilename().strref() :
"<unknown>";
182 fileID = getNextID();
184 operands.push_back(fileID);
190 void Serializer::processExtension() {
192 for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
199 void Serializer::processMemoryModel() {
200 StringAttr memoryModelName = module.getMemoryModelAttrName();
201 auto mm =
static_cast<uint32_t
>(
202 module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
205 StringAttr addressingModelName = module.getAddressingModelAttrName();
206 auto am =
static_cast<uint32_t
>(
207 module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
216 if (attrName ==
"fp_fast_math_mode")
217 return "FPFastMathMode";
219 return llvm::convertToCamelFromSnakeCase(attrName,
true);
223 Decoration decoration,
226 switch (decoration) {
227 case spirv::Decoration::LinkageAttributes: {
230 auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
231 auto linkageName = linkageAttr.getLinkageName();
232 auto linkageType = linkageAttr.getLinkageType().getValue();
236 args.push_back(
static_cast<uint32_t
>(linkageType));
239 case spirv::Decoration::FPFastMathMode:
240 if (
auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
241 args.push_back(
static_cast<uint32_t
>(intAttr.getValue()));
244 return emitError(loc,
"expected FPFastMathModeAttr attribute for ")
245 << stringifyDecoration(decoration);
246 case spirv::Decoration::Binding:
247 case spirv::Decoration::DescriptorSet:
248 case spirv::Decoration::Location:
249 if (
auto intAttr = dyn_cast<IntegerAttr>(attr)) {
250 args.push_back(intAttr.getValue().getZExtValue());
253 return emitError(loc,
"expected integer attribute for ")
254 << stringifyDecoration(decoration);
255 case spirv::Decoration::BuiltIn:
256 if (
auto strAttr = dyn_cast<StringAttr>(attr)) {
257 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
259 args.push_back(
static_cast<uint32_t
>(*enumVal));
263 << stringifyDecoration(decoration) <<
" decoration attribute "
264 << strAttr.getValue();
266 return emitError(loc,
"expected string attribute for ")
267 << stringifyDecoration(decoration);
268 case spirv::Decoration::Aliased:
269 case spirv::Decoration::AliasedPointer:
270 case spirv::Decoration::Flat:
271 case spirv::Decoration::NonReadable:
272 case spirv::Decoration::NonWritable:
273 case spirv::Decoration::NoPerspective:
274 case spirv::Decoration::NoSignedWrap:
275 case spirv::Decoration::NoUnsignedWrap:
276 case spirv::Decoration::RelaxedPrecision:
277 case spirv::Decoration::Restrict:
278 case spirv::Decoration::RestrictPointer:
279 case spirv::Decoration::NoContraction:
282 if (isa<UnitAttr, DecorationAttr>(attr))
285 "expected unit attribute or decoration attribute for ")
286 << stringifyDecoration(decoration);
288 return emitError(loc,
"unhandled decoration ")
289 << stringifyDecoration(decoration);
291 return emitDecoration(resultID, decoration, args);
296 StringRef attrName = attr.
getName().strref();
298 std::optional<Decoration> decoration =
299 spirv::symbolizeDecoration(decorationName);
302 loc,
"non-argument attributes expected to have snake-case-ified "
303 "decoration name, unhandled attribute with name : ")
306 return processDecorationAttr(loc, resultID, *decoration, attr.
getValue());
309 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
310 assert(!name.empty() &&
"unexpected empty string for OpName");
315 nameOperands.push_back(resultID);
326 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
336 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
346 static_cast<uint32_t
>(memberDecoration.
decoration)});
361 bool Serializer::isInterfaceStructPtrType(
Type type)
const {
362 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
363 switch (ptrType.getStorageClass()) {
364 case spirv::StorageClass::PhysicalStorageBuffer:
365 case spirv::StorageClass::PushConstant:
366 case spirv::StorageClass::StorageBuffer:
367 case spirv::StorageClass::Uniform:
368 return isa<spirv::StructType>(ptrType.getPointeeType());
381 return processTypeImpl(loc, type, typeID, serializationCtx);
385 Serializer::processTypeImpl(
Location loc,
Type type, uint32_t &typeID,
387 typeID = getTypeID(type);
391 typeID = getNextID();
394 operands.push_back(typeID);
395 auto typeEnum = spirv::Opcode::OpTypeVoid;
396 bool deferSerialization =
false;
398 if ((isa<FunctionType>(type) &&
399 succeeded(prepareFunctionType(loc, cast<FunctionType>(type), typeEnum,
401 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands,
402 deferSerialization, serializationCtx))) {
403 if (deferSerialization)
406 typeIDMap[type] = typeID;
410 if (recursiveStructInfos.count(type) != 0) {
413 for (
auto &ptrInfo : recursiveStructInfos[type]) {
417 ptrOperands.push_back(ptrInfo.pointerTypeID);
418 ptrOperands.push_back(
static_cast<uint32_t
>(ptrInfo.storageClass));
419 ptrOperands.push_back(typeIDMap[type]);
425 recursiveStructInfos[type].clear();
435 Location loc,
Type type, uint32_t resultID, spirv::Opcode &typeEnum,
438 deferSerialization =
false;
440 if (isVoidType(type)) {
441 typeEnum = spirv::Opcode::OpTypeVoid;
445 if (
auto intType = dyn_cast<IntegerType>(type)) {
446 if (intType.getWidth() == 1) {
447 typeEnum = spirv::Opcode::OpTypeBool;
451 typeEnum = spirv::Opcode::OpTypeInt;
452 operands.push_back(intType.getWidth());
457 operands.push_back(intType.isSigned() ? 1 : 0);
461 if (
auto floatType = dyn_cast<FloatType>(type)) {
462 typeEnum = spirv::Opcode::OpTypeFloat;
463 operands.push_back(floatType.getWidth());
467 if (
auto vectorType = dyn_cast<VectorType>(type)) {
468 uint32_t elementTypeID = 0;
469 if (
failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
470 serializationCtx))) {
473 typeEnum = spirv::Opcode::OpTypeVector;
474 operands.push_back(elementTypeID);
475 operands.push_back(vectorType.getNumElements());
479 if (
auto imageType = dyn_cast<spirv::ImageType>(type)) {
480 typeEnum = spirv::Opcode::OpTypeImage;
481 uint32_t sampledTypeID = 0;
482 if (
failed(processType(loc, imageType.getElementType(), sampledTypeID)))
485 llvm::append_values(operands, sampledTypeID,
486 static_cast<uint32_t
>(imageType.getDim()),
487 static_cast<uint32_t
>(imageType.getDepthInfo()),
488 static_cast<uint32_t
>(imageType.getArrayedInfo()),
489 static_cast<uint32_t
>(imageType.getSamplingInfo()),
490 static_cast<uint32_t
>(imageType.getSamplerUseInfo()),
491 static_cast<uint32_t
>(imageType.getImageFormat()));
495 if (
auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
496 typeEnum = spirv::Opcode::OpTypeArray;
497 uint32_t elementTypeID = 0;
498 if (
failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
499 serializationCtx))) {
502 operands.push_back(elementTypeID);
503 if (
auto elementCountID = prepareConstantInt(
505 operands.push_back(elementCountID);
507 return processTypeDecoration(loc, arrayType, resultID);
510 if (
auto ptrType = dyn_cast<spirv::PointerType>(type)) {
511 uint32_t pointeeTypeID = 0;
513 dyn_cast<spirv::StructType>(ptrType.getPointeeType());
516 serializationCtx.count(pointeeStruct.
getIdentifier()) != 0) {
522 forwardPtrOperands.push_back(resultID);
523 forwardPtrOperands.push_back(
524 static_cast<uint32_t
>(ptrType.getStorageClass()));
527 spirv::Opcode::OpTypeForwardPointer,
539 deferSerialization =
true;
543 recursiveStructInfos[structType].push_back(
544 {resultID, ptrType.getStorageClass()});
546 if (
failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
551 typeEnum = spirv::Opcode::OpTypePointer;
552 operands.push_back(
static_cast<uint32_t
>(ptrType.getStorageClass()));
553 operands.push_back(pointeeTypeID);
555 if (isInterfaceStructPtrType(ptrType)) {
556 if (
failed(emitDecoration(getTypeID(pointeeStruct),
557 spirv::Decoration::Block)))
558 return emitError(loc,
"cannot decorate ")
559 << pointeeStruct <<
" with Block decoration";
565 if (
auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
566 uint32_t elementTypeID = 0;
567 if (
failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
568 elementTypeID, serializationCtx))) {
571 typeEnum = spirv::Opcode::OpTypeRuntimeArray;
572 operands.push_back(elementTypeID);
573 return processTypeDecoration(loc, runtimeArrayType, resultID);
576 if (
auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
577 typeEnum = spirv::Opcode::OpTypeSampledImage;
578 uint32_t imageTypeID = 0;
580 processType(loc, sampledImageType.getImageType(), imageTypeID))) {
583 operands.push_back(imageTypeID);
587 if (
auto structType = dyn_cast<spirv::StructType>(type)) {
588 if (structType.isIdentified()) {
589 if (
failed(processName(resultID, structType.getIdentifier())))
591 serializationCtx.insert(structType.getIdentifier());
594 bool hasOffset = structType.hasOffset();
595 for (
auto elementIndex :
596 llvm::seq<uint32_t>(0, structType.getNumElements())) {
597 uint32_t elementTypeID = 0;
598 if (
failed(processTypeImpl(loc, structType.getElementType(elementIndex),
599 elementTypeID, serializationCtx))) {
602 operands.push_back(elementTypeID);
606 elementIndex, 1, spirv::Decoration::Offset,
607 static_cast<uint32_t
>(structType.getMemberOffset(elementIndex))};
608 if (
failed(processMemberDecoration(resultID, offsetDecoration))) {
609 return emitError(loc,
"cannot decorate ")
610 << elementIndex <<
"-th member of " << structType
611 <<
" with its offset";
616 structType.getMemberDecorations(memberDecorations);
618 for (
auto &memberDecoration : memberDecorations) {
619 if (
failed(processMemberDecoration(resultID, memberDecoration))) {
620 return emitError(loc,
"cannot decorate ")
621 <<
static_cast<uint32_t
>(memberDecoration.
memberIndex)
622 <<
"-th member of " << structType <<
" with "
623 << stringifyDecoration(memberDecoration.
decoration);
627 typeEnum = spirv::Opcode::OpTypeStruct;
629 if (structType.isIdentified())
630 serializationCtx.remove(structType.getIdentifier());
635 if (
auto cooperativeMatrixType =
636 dyn_cast<spirv::CooperativeMatrixType>(type)) {
637 uint32_t elementTypeID = 0;
638 if (
failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
639 elementTypeID, serializationCtx))) {
642 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
643 auto getConstantOp = [&](uint32_t id) {
645 return prepareConstantInt(loc, attr);
648 operands, elementTypeID,
649 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getScope())),
650 getConstantOp(cooperativeMatrixType.getRows()),
651 getConstantOp(cooperativeMatrixType.getColumns()),
652 getConstantOp(
static_cast<uint32_t
>(cooperativeMatrixType.getUse())));
656 if (
auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
657 uint32_t elementTypeID = 0;
658 if (
failed(processTypeImpl(loc, jointMatrixType.getElementType(),
659 elementTypeID, serializationCtx))) {
662 typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
663 auto getConstantOp = [&](uint32_t id) {
665 return prepareConstantInt(loc, attr);
668 operands, elementTypeID, getConstantOp(jointMatrixType.getRows()),
669 getConstantOp(jointMatrixType.getColumns()),
670 getConstantOp(
static_cast<uint32_t
>(jointMatrixType.getMatrixLayout())),
671 getConstantOp(
static_cast<uint32_t
>(jointMatrixType.getScope())));
675 if (
auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
676 uint32_t elementTypeID = 0;
677 if (
failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
678 serializationCtx))) {
681 typeEnum = spirv::Opcode::OpTypeMatrix;
682 llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
687 return emitError(loc,
"unhandled type in serialization: ") << type;
691 Serializer::prepareFunctionType(
Location loc, FunctionType type,
692 spirv::Opcode &typeEnum,
694 typeEnum = spirv::Opcode::OpTypeFunction;
695 assert(type.getNumResults() <= 1 &&
696 "serialization supports only a single return value");
697 uint32_t resultID = 0;
699 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(),
703 operands.push_back(resultID);
704 for (
auto &res : type.getInputs()) {
705 uint32_t argTypeID = 0;
706 if (
failed(processType(loc, res, argTypeID))) {
709 operands.push_back(argTypeID);
718 uint32_t Serializer::prepareConstant(
Location loc,
Type constType,
720 if (
auto id = prepareConstantScalar(loc, valueAttr)) {
727 if (
auto id = getConstantID(valueAttr)) {
732 if (
failed(processType(loc, constType, typeID))) {
736 uint32_t resultID = 0;
737 if (
auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
738 int rank = dyn_cast<ShapedType>(attr.getType()).getRank();
740 resultID = prepareDenseElementsConstant(loc, constType, attr,
742 }
else if (
auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
743 resultID = prepareArrayConstant(loc, constType, arrayAttr);
747 emitError(loc,
"cannot serialize attribute: ") << valueAttr;
751 constIDMap[valueAttr] = resultID;
755 uint32_t Serializer::prepareArrayConstant(
Location loc,
Type constType,
758 if (
failed(processType(loc, constType, typeID))) {
762 uint32_t resultID = getNextID();
764 operands.reserve(attr.size() + 2);
765 auto elementType = cast<spirv::ArrayType>(constType).getElementType();
767 if (
auto elementID = prepareConstant(loc, elementType, elementAttr)) {
768 operands.push_back(elementID);
773 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
782 Serializer::prepareDenseElementsConstant(
Location loc,
Type constType,
785 auto shapedType = dyn_cast<ShapedType>(valueAttr.
getType());
786 assert(dim <= shapedType.getRank());
787 if (shapedType.getRank() == dim) {
788 if (
auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
789 return attr.getType().getElementType().isInteger(1)
790 ? prepareConstantBool(loc, attr.getValues<
BoolAttr>()[index])
791 : prepareConstantInt(loc,
792 attr.getValues<IntegerAttr>()[index]);
794 if (
auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
795 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
801 if (
failed(processType(loc, constType, typeID))) {
805 uint32_t resultID = getNextID();
807 operands.reserve(shapedType.getDimSize(dim) + 2);
808 auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
809 for (
int i = 0; i < shapedType.getDimSize(dim); ++i) {
811 if (
auto elementID = prepareDenseElementsConstant(
812 loc, elementType, valueAttr, dim + 1, index)) {
813 operands.push_back(elementID);
818 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
826 if (
auto floatAttr = dyn_cast<FloatAttr>(valueAttr)) {
827 return prepareConstantFp(loc, floatAttr, isSpec);
829 if (
auto boolAttr = dyn_cast<BoolAttr>(valueAttr)) {
830 return prepareConstantBool(loc, boolAttr, isSpec);
832 if (
auto intAttr = dyn_cast<IntegerAttr>(valueAttr)) {
833 return prepareConstantInt(loc, intAttr, isSpec);
843 if (
auto id = getConstantID(boolAttr)) {
850 if (
failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID))) {
854 auto resultID = getNextID();
856 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue
857 : spirv::Opcode::OpConstantTrue)
858 : (isSpec ? spirv::Opcode::OpSpecConstantFalse
859 : spirv::Opcode::OpConstantFalse);
863 constIDMap[boolAttr] = resultID;
868 uint32_t Serializer::prepareConstantInt(
Location loc, IntegerAttr intAttr,
872 if (
auto id = getConstantID(intAttr)) {
879 if (
failed(processType(loc, intAttr.getType(), typeID))) {
883 auto resultID = getNextID();
884 APInt value = intAttr.getValue();
885 unsigned bitwidth = value.getBitWidth();
886 bool isSigned = intAttr.getType().isSignedInteger();
888 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
901 word =
static_cast<int32_t
>(value.getSExtValue());
903 word =
static_cast<uint32_t
>(value.getZExtValue());
915 words = llvm::bit_cast<DoubleWord>(value.getSExtValue());
917 words = llvm::bit_cast<DoubleWord>(value.getZExtValue());
920 {typeID, resultID, words.word1, words.word2});
923 std::string valueStr;
924 llvm::raw_string_ostream rss(valueStr);
925 value.print(rss,
false);
928 << bitwidth <<
"-bit integer literal: " << rss.str();
934 constIDMap[intAttr] = resultID;
939 uint32_t Serializer::prepareConstantFp(
Location loc, FloatAttr floatAttr,
943 if (
auto id = getConstantID(floatAttr)) {
950 if (
failed(processType(loc, floatAttr.getType(), typeID))) {
954 auto resultID = getNextID();
955 APFloat value = floatAttr.getValue();
956 APInt intValue = value.bitcastToAPInt();
959 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;
961 if (&value.getSemantics() == &APFloat::IEEEsingle()) {
962 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
964 }
else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
968 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
970 {typeID, resultID, words.word1, words.word2});
971 }
else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
973 static_cast<uint32_t
>(value.bitcastToAPInt().getZExtValue());
976 std::string valueStr;
977 llvm::raw_string_ostream rss(valueStr);
981 << floatAttr.getType() <<
"-typed float literal: " << rss.str();
986 constIDMap[floatAttr] = resultID;
995 uint32_t Serializer::getOrCreateBlockID(
Block *block) {
996 if (uint32_t
id = getBlockID(block))
998 return blockIDMap[block] = getNextID();
1002 void Serializer::printBlock(
Block *block, raw_ostream &os) {
1003 os <<
"block " << block <<
" (id = ";
1004 if (uint32_t
id = getBlockID(block))
1013 Serializer::processBlock(
Block *block,
bool omitLabel,
1015 LLVM_DEBUG(llvm::dbgs() <<
"processing block " << block <<
":\n");
1016 LLVM_DEBUG(block->
print(llvm::dbgs()));
1017 LLVM_DEBUG(llvm::dbgs() <<
'\n');
1019 uint32_t blockID = getOrCreateBlockID(block);
1020 LLVM_DEBUG(printBlock(block, llvm::dbgs()));
1027 if (
failed(emitPhiForBlockArguments(block)))
1037 llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
1040 emitMerge =
nullptr;
1043 uint32_t blockID = getNextID();
1049 for (
Operation &op : llvm::drop_end(*block)) {
1050 if (
failed(processOperation(&op)))
1058 if (
failed(processOperation(&block->
back())))
1070 LLVM_DEBUG(llvm::dbgs() <<
"emitting phi instructions..\n");
1079 auto *terminator = mlirPredecessor->getTerminator();
1080 LLVM_DEBUG(llvm::dbgs() <<
" mlir predecessor ");
1081 LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
1082 LLVM_DEBUG(llvm::dbgs() <<
" terminator: " << *terminator <<
"\n");
1091 LLVM_DEBUG(llvm::dbgs() <<
" spirv predecessor ");
1092 LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
1093 if (
auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
1094 predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
1095 }
else if (
auto branchCondOp =
1096 dyn_cast<spirv::BranchConditionalOp>(terminator)) {
1097 std::optional<OperandRange> blockOperands;
1098 if (branchCondOp.getTrueTarget() == block) {
1099 blockOperands = branchCondOp.getTrueTargetOperands();
1101 assert(branchCondOp.getFalseTarget() == block);
1102 blockOperands = branchCondOp.getFalseTargetOperands();
1105 assert(!blockOperands->empty() &&
1106 "expected non-empty block operand range");
1107 predecessors.emplace_back(spirvPredecessor, *blockOperands);
1109 return terminator->emitError(
"unimplemented terminator for Phi creation");
1112 llvm::dbgs() <<
" block arguments:\n";
1113 for (
Value v : predecessors.back().second)
1114 llvm::dbgs() <<
" " << v <<
"\n";
1119 for (
auto argIndex : llvm::seq<unsigned>(0, block->
getNumArguments())) {
1123 uint32_t phiTypeID = 0;
1126 uint32_t phiID = getNextID();
1128 LLVM_DEBUG(llvm::dbgs() <<
"[phi] for block argument #" << argIndex <<
' '
1129 << arg <<
" (id = " << phiID <<
")\n");
1133 phiArgs.push_back(phiTypeID);
1134 phiArgs.push_back(phiID);
1136 for (
auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
1137 Value value = predecessors[predIndex].second[argIndex];
1138 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
1139 LLVM_DEBUG(llvm::dbgs() <<
"[phi] use predecessor (id = " << predBlockId
1140 <<
") value " << value <<
' ');
1142 uint32_t valueId = getValueID(value);
1146 LLVM_DEBUG(llvm::dbgs() <<
"(need to fix)\n");
1147 deferredPhiValues[value].push_back(functionBody.size() + 1 +
1150 LLVM_DEBUG(llvm::dbgs() <<
"(id = " << valueId <<
")\n");
1152 phiArgs.push_back(valueId);
1154 phiArgs.push_back(predBlockId);
1158 valueIDMap[arg] = phiID;
1169 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
1172 auto &setID = extendedInstSetIDMap[extensionSetName];
1174 setID = getNextID();
1176 importOperands.push_back(setID);
1184 if (operands.size() < 2) {
1185 return op->
emitError(
"extended instructions must have a result encoding");
1188 extInstOperands.reserve(operands.size() + 2);
1189 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
1190 extInstOperands.push_back(setID);
1191 extInstOperands.push_back(extensionOpcode);
1192 extInstOperands.append(std::next(operands.begin(), 2), operands.end());
1199 LLVM_DEBUG(llvm::dbgs() <<
"[op] '" << opInst->
getName() <<
"'\n");
1204 .Case([&](spirv::AddressOfOp op) {
return processAddressOfOp(op); })
1205 .Case([&](spirv::BranchOp op) {
return processBranchOp(op); })
1206 .Case([&](spirv::BranchConditionalOp op) {
1207 return processBranchConditionalOp(op);
1209 .Case([&](spirv::ConstantOp op) {
return processConstantOp(op); })
1210 .Case([&](spirv::FuncOp op) {
return processFuncOp(op); })
1211 .Case([&](spirv::GlobalVariableOp op) {
1212 return processGlobalVariableOp(op);
1214 .Case([&](spirv::LoopOp op) {
return processLoopOp(op); })
1215 .Case([&](spirv::ReferenceOfOp op) {
return processReferenceOfOp(op); })
1216 .Case([&](spirv::SelectionOp op) {
return processSelectionOp(op); })
1217 .Case([&](spirv::SpecConstantOp op) {
return processSpecConstantOp(op); })
1218 .Case([&](spirv::SpecConstantCompositeOp op) {
1219 return processSpecConstantCompositeOp(op);
1221 .Case([&](spirv::SpecConstantOperationOp op) {
1222 return processSpecConstantOperationOp(op);
1224 .Case([&](spirv::UndefOp op) {
return processUndefOp(op); })
1225 .Case([&](spirv::VariableOp op) {
return processVariableOp(op); })
1230 [&](
Operation *op) {
return dispatchToAutogenSerialization(op); });
1234 StringRef extInstSet,
1239 uint32_t resultID = 0;
1241 uint32_t resultTypeID = 0;
1244 operands.push_back(resultTypeID);
1246 resultID = getNextID();
1247 operands.push_back(resultID);
1248 valueIDMap[op->
getResult(0)] = resultID;
1252 operands.push_back(getValueID(operand));
1254 if (
failed(emitDebugLine(functionBody, loc)))
1257 if (extInstSet.empty()) {
1261 if (
failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
1267 if (
failed(processDecoration(loc, resultID, attr)))
1276 spirv::Decoration decoration,
1278 uint32_t wordCount = 3 + params.size();
1279 llvm::append_values(
1282 static_cast<uint32_t
>(decoration));
1283 llvm::append_range(decorations, params);
1292 if (lastProcessedWasMergeInst) {
1293 lastProcessedWasMergeInst =
false;
1297 auto fileLoc = dyn_cast<FileLineColLoc>(loc);
1300 {fileID, fileLoc.getLine(), fileLoc.getColumn()});
static MLIRContext * getContext(OpFoldResult val)
static llvm::ManagedStatic< PassManagerOptions > options
static Block * getStructuredControlFlowOpMergeBlock(Operation *op)
Returns the merge block if the given op is a structured control flow op.
static Block * getPhiIncomingBlock(Block *block)
Given a predecessor block for a block with arguments, returns the block that should be used as the pa...
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< pred_iterator > getPredecessors()
OpListType & getOperations()
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
bool getValue() const
Return the boolean value of this attribute.
IntegerAttr getI32IntegerAttr(int32_t value)
An attribute that represents a reference to a dense vector or tensor object.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Block * getBlock()
Returns the operation block that contains this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
unsigned getNumResults()
Return the number of results held by this operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
unsigned getArrayStride() const
Returns the array stride in bytes.
unsigned getArrayStride() const
Returns the array stride in bytes.
void printValueIDMap(raw_ostream &os)
(For debugging) prints each value and its corresponding result <id>.
Serializer(spirv::ModuleOp module, const SerializationOptions &options)
Creates a serializer for the given SPIR-V module.
LogicalResult serialize()
Serializes the remembered SPIR-V module.
void collect(SmallVectorImpl< uint32_t > &binary)
Collects the final SPIR-V binary.
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
bool isIdentified() const
Returns true if the StructType is identified.
StringRef getIdentifier() const
For literal structs, return an empty string.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode)
Returns the word-count-prefixed opcode for an SPIR-V instruction.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
void appendModuleHeader(SmallVectorImpl< uint32_t > &header, spirv::Version version, uint32_t idBound)
Appends a SPRI-V module header to header with the given version and idBound.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
static std::string getDecorationName(StringRef attrName)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
bool emitSymbolName
Whether to emit OpName instructions for SPIR-V symbol ops.
bool emitDebugInfo
Whether to emit OpLine location information for SPIR-V ops.