24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
36 #define DEBUG_TYPE "spirv-deserialization"
45 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion())
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
79 if (
failed(sliceInstruction(opcode, operands)))
82 if (
failed(processInstruction(opcode, operands)))
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
90 if (
failed(processInstruction(deferred.first, deferred.second,
false))) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
103 return std::move(module);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
141 #undef MIN_VERSION_CASE
143 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
147 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158 if (operands.size() != 1)
159 return emitError(unknownLoc,
"OpMemoryModel must have one parameter");
161 auto cap = spirv::symbolizeCapability(operands[0]);
163 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
165 capabilities.insert(*cap);
173 "OpExtension must have a literal string for the extension name");
176 unsigned wordIndex = 0;
178 if (wordIndex != words.size())
180 "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
183 return emitError(unknownLoc,
"unknown extension: ") << extName;
185 extensions.insert(*ext);
191 if (words.size() < 2) {
193 "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
197 unsigned wordIndex = 1;
199 if (wordIndex != words.size()) {
201 "unexpected trailing words in OpExtInstImport");
206 void spirv::Deserializer::attachVCETriple() {
208 spirv::ModuleOp::getVCETripleAttrName(),
210 extensions.getArrayRef(), context));
215 if (operands.size() != 2)
216 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
220 opBuilder.getAttr<spirv::AddressingModelAttr>(
221 static_cast<spirv::AddressingModel
>(operands.front())));
222 (*module)->setAttr(
"memory_model",
223 opBuilder.getAttr<spirv::MemoryModelAttr>(
224 static_cast<spirv::MemoryModel
>(operands.back())));
233 if (words.size() < 2) {
235 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
237 auto decorationName =
238 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
239 if (decorationName.empty()) {
240 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
242 auto attrName = llvm::convertToSnakeFromCamelCase(decorationName);
243 auto symbol = opBuilder.getStringAttr(attrName);
244 switch (
static_cast<spirv::Decoration
>(words[1])) {
245 case spirv::Decoration::FPFastMathMode:
246 if (words.size() != 3) {
247 return emitError(unknownLoc,
"OpDecorate with ")
248 << decorationName <<
" needs a single integer literal";
250 decorations[words[0]].set(
252 static_cast<FPFastMathMode
>(words[2])));
254 case spirv::Decoration::DescriptorSet:
255 case spirv::Decoration::Binding:
256 if (words.size() != 3) {
257 return emitError(unknownLoc,
"OpDecorate with ")
258 << decorationName <<
" needs a single integer literal";
260 decorations[words[0]].set(
261 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
263 case spirv::Decoration::BuiltIn:
264 if (words.size() != 3) {
265 return emitError(unknownLoc,
"OpDecorate with ")
266 << decorationName <<
" needs a single integer literal";
268 decorations[words[0]].set(
269 symbol, opBuilder.getStringAttr(
270 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
272 case spirv::Decoration::ArrayStride:
273 if (words.size() != 3) {
274 return emitError(unknownLoc,
"OpDecorate with ")
275 << decorationName <<
" needs a single integer literal";
277 typeDecorations[words[0]] = words[2];
279 case spirv::Decoration::LinkageAttributes: {
280 if (words.size() < 4) {
281 return emitError(unknownLoc,
"OpDecorate with ")
283 <<
" needs at least 1 string and 1 integer literal";
291 unsigned wordIndex = 2;
293 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
294 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
295 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
296 linkageName, linkageTypeAttr);
297 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
300 case spirv::Decoration::Aliased:
301 case spirv::Decoration::Block:
302 case spirv::Decoration::BufferBlock:
303 case spirv::Decoration::Flat:
304 case spirv::Decoration::NonReadable:
305 case spirv::Decoration::NonWritable:
306 case spirv::Decoration::NoPerspective:
307 case spirv::Decoration::NoSignedWrap:
308 case spirv::Decoration::NoUnsignedWrap:
309 case spirv::Decoration::RelaxedPrecision:
310 case spirv::Decoration::Restrict:
311 if (words.size() != 2) {
312 return emitError(unknownLoc,
"OpDecoration with ")
313 << decorationName <<
"needs a single target <id>";
319 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
321 case spirv::Decoration::Location:
322 case spirv::Decoration::SpecId:
323 if (words.size() != 3) {
324 return emitError(unknownLoc,
"OpDecoration with ")
325 << decorationName <<
"needs a single integer literal";
327 decorations[words[0]].set(
328 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
331 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
339 if (words.size() < 3) {
341 "OpMemberDecorate must have at least 3 operands");
344 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
345 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
347 " missing offset specification in OpMemberDecorate with "
348 "Offset decoration");
351 if (words.size() > 3) {
352 decorationOperands = words.slice(3);
354 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
359 if (words.size() < 3) {
360 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
362 unsigned wordIndex = 2;
364 if (wordIndex != words.size()) {
366 "unexpected trailing words in OpMemberName instruction");
368 memberNameMap[words[0]][words[1]] = name;
375 return emitError(unknownLoc,
"found function inside function");
379 if (operands.size() != 4) {
380 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
382 Type resultType = getType(operands[0]);
384 return emitError(unknownLoc,
"undefined result type from <id> ")
388 uint32_t fnID = operands[1];
389 if (funcMap.count(fnID)) {
390 return emitError(unknownLoc,
"duplicate function definition/declaration");
393 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
395 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
398 Type fnType = getType(operands[3]);
399 if (!fnType || !isa<FunctionType>(fnType)) {
400 return emitError(unknownLoc,
"unknown function type from <id> ")
403 auto functionType = cast<FunctionType>(fnType);
405 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
406 (functionType.getNumResults() == 1 &&
407 functionType.getResult(0) != resultType)) {
408 return emitError(unknownLoc,
"mismatch in function type ")
409 << functionType <<
" and return type " << resultType <<
" specified";
412 std::string fnName = getFunctionSymbol(fnID);
413 auto funcOp = opBuilder.create<spirv::FuncOp>(
414 unknownLoc, fnName, functionType, fnControl.value());
416 if (decorations.count(fnID)) {
417 for (
auto attr : decorations[fnID].getAttrs()) {
418 funcOp->setAttr(attr.getName(), attr.getValue());
421 curFunction = funcMap[fnID] = funcOp;
422 auto *entryBlock = funcOp.addEntryBlock();
425 <<
"//===-------------------------------------------===//\n";
426 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
427 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
428 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
429 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
434 if (functionType.getNumInputs()) {
435 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
436 auto argType = functionType.getInput(i);
437 spirv::Opcode opcode = spirv::Opcode::OpNop;
439 if (
failed(sliceInstruction(opcode, operands,
440 spirv::Opcode::OpFunctionParameter))) {
443 if (opcode != spirv::Opcode::OpFunctionParameter) {
446 "missing OpFunctionParameter instruction for argument ")
449 if (operands.size() != 2) {
452 "expected result type and result <id> for OpFunctionParameter");
454 auto argDefinedType = getType(operands[0]);
455 if (!argDefinedType || argDefinedType != argType) {
457 "mismatch in argument type between function type "
459 << functionType <<
" and argument type definition "
460 << argDefinedType <<
" at argument " << i;
462 if (getValue(operands[1])) {
463 return emitError(unknownLoc,
"duplicate definition of result <id> ")
466 auto argValue = funcOp.getArgument(i);
467 valueMap[operands[1]] = argValue;
474 auto linkageAttr = funcOp.getLinkageAttributes();
475 auto hasImportLinkage =
476 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
477 spirv::LinkageType::Import);
478 if (hasImportLinkage)
485 spirv::Opcode opcode = spirv::Opcode::OpNop;
493 if (
failed(sliceInstruction(opcode, instOperands,
494 spirv::Opcode::OpFunctionEnd))) {
497 if (opcode == spirv::Opcode::OpFunctionEnd) {
498 return processFunctionEnd(instOperands);
500 if (opcode != spirv::Opcode::OpLabel) {
501 return emitError(unknownLoc,
"a basic block must start with OpLabel");
503 if (instOperands.size() != 1) {
504 return emitError(unknownLoc,
"OpLabel should only have result <id>");
506 blockMap[instOperands[0]] = entryBlock;
507 if (
failed(processLabel(instOperands))) {
513 while (
succeeded(sliceInstruction(opcode, instOperands,
514 spirv::Opcode::OpFunctionEnd)) &&
515 opcode != spirv::Opcode::OpFunctionEnd) {
516 if (
failed(processInstruction(opcode, instOperands))) {
520 if (opcode != spirv::Opcode::OpFunctionEnd) {
524 return processFunctionEnd(instOperands);
530 if (!operands.empty()) {
531 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
537 if (
failed(wireUpBlockArgument()) ||
failed(structurizeControlFlow())) {
542 curFunction = std::nullopt;
547 <<
"//===-------------------------------------------===//\n";
552 std::optional<std::pair<Attribute, Type>>
553 spirv::Deserializer::getConstant(uint32_t
id) {
554 auto constIt = constantMap.find(
id);
555 if (constIt == constantMap.end())
557 return constIt->getSecond();
560 std::optional<spirv::SpecConstOperationMaterializationInfo>
561 spirv::Deserializer::getSpecConstantOperation(uint32_t
id) {
562 auto constIt = specConstOperationMap.find(
id);
563 if (constIt == specConstOperationMap.end())
565 return constIt->getSecond();
568 std::string spirv::Deserializer::getFunctionSymbol(uint32_t
id) {
569 auto funcName = nameMap.lookup(
id).str();
570 if (funcName.empty()) {
571 funcName =
"spirv_fn_" + std::to_string(
id);
576 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t
id) {
577 auto constName = nameMap.lookup(
id).str();
578 if (constName.empty()) {
579 constName =
"spirv_spec_const_" + std::to_string(
id);
584 spirv::SpecConstantOp
585 spirv::Deserializer::createSpecConstant(
Location loc, uint32_t resultID,
586 TypedAttr defaultValue) {
587 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
588 auto op = opBuilder.
create<spirv::SpecConstantOp>(unknownLoc, symName,
590 if (decorations.count(resultID)) {
591 for (
auto attr : decorations[resultID].getAttrs())
592 op->
setAttr(attr.getName(), attr.getValue());
594 specConstMap[resultID] = op;
600 unsigned wordIndex = 0;
601 if (operands.size() < 3) {
604 "OpVariable needs at least 3 operands, type, <id> and storage class");
608 auto type = getType(operands[wordIndex]);
610 return emitError(unknownLoc,
"unknown result type <id> : ")
611 << operands[wordIndex];
613 auto ptrType = dyn_cast<spirv::PointerType>(type);
616 "expected a result type <id> to be a spirv.ptr, found : ")
622 auto variableID = operands[wordIndex];
623 auto variableName = nameMap.lookup(variableID).str();
624 if (variableName.empty()) {
625 variableName =
"spirv_var_" + std::to_string(variableID);
630 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
631 if (ptrType.getStorageClass() != storageClass) {
632 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
633 << type <<
" and that specified in OpVariable instruction : "
634 << stringifyStorageClass(storageClass);
640 if (wordIndex < operands.size()) {
641 auto initializerOp = getGlobalVariable(operands[wordIndex]);
642 if (!initializerOp) {
643 return emitError(unknownLoc,
"unknown <id> ")
644 << operands[wordIndex] <<
"used as initializer";
649 if (wordIndex != operands.size()) {
651 "found more operands than expected when deserializing "
652 "OpVariable instruction, only ")
653 << wordIndex <<
" of " << operands.size() <<
" processed";
655 auto loc = createFileLineColLoc(opBuilder);
656 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
657 loc,
TypeAttr::get(type), opBuilder.getStringAttr(variableName),
661 if (decorations.count(variableID)) {
662 for (
auto attr : decorations[variableID].getAttrs())
663 varOp->setAttr(attr.getName(), attr.getValue());
665 globalVariableMap[variableID] = varOp;
669 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t
id) {
670 auto constInfo = getConstant(
id);
674 return dyn_cast<IntegerAttr>(constInfo->first);
678 if (operands.size() < 2) {
679 return emitError(unknownLoc,
"OpName needs at least 2 operands");
681 if (!nameMap.lookup(operands[0]).empty()) {
682 return emitError(unknownLoc,
"duplicate name found for result <id> ")
685 unsigned wordIndex = 1;
687 if (wordIndex != operands.size()) {
689 "unexpected trailing words in OpName instruction");
691 nameMap[operands[0]] = name;
699 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
701 if (operands.empty()) {
702 return emitError(unknownLoc,
"type instruction with opcode ")
703 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
708 if (typeMap.count(operands[0])) {
709 return emitError(unknownLoc,
"duplicate definition for result <id> ")
714 case spirv::Opcode::OpTypeVoid:
715 if (operands.size() != 1)
716 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
717 typeMap[operands[0]] = opBuilder.getNoneType();
719 case spirv::Opcode::OpTypeBool:
720 if (operands.size() != 1)
721 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
722 typeMap[operands[0]] = opBuilder.getI1Type();
724 case spirv::Opcode::OpTypeInt: {
725 if (operands.size() != 3)
727 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
737 : IntegerType::SignednessSemantics::Signless;
740 case spirv::Opcode::OpTypeFloat: {
741 if (operands.size() != 2)
742 return emitError(unknownLoc,
"OpTypeFloat must have bitwidth parameter");
745 switch (operands[1]) {
747 floatTy = opBuilder.getF16Type();
750 floatTy = opBuilder.getF32Type();
753 floatTy = opBuilder.getF64Type();
756 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
759 typeMap[operands[0]] = floatTy;
761 case spirv::Opcode::OpTypeVector: {
762 if (operands.size() != 3) {
765 "OpTypeVector must have element type and count parameters");
767 Type elementTy = getType(operands[1]);
769 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
774 case spirv::Opcode::OpTypePointer: {
775 return processOpTypePointer(operands);
777 case spirv::Opcode::OpTypeArray:
778 return processArrayType(operands);
779 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
780 return processCooperativeMatrixTypeKHR(operands);
781 case spirv::Opcode::OpTypeCooperativeMatrixNV:
782 return processCooperativeMatrixTypeNV(operands);
783 case spirv::Opcode::OpTypeFunction:
784 return processFunctionType(operands);
785 case spirv::Opcode::OpTypeJointMatrixINTEL:
786 return processJointMatrixType(operands);
787 case spirv::Opcode::OpTypeImage:
788 return processImageType(operands);
789 case spirv::Opcode::OpTypeSampledImage:
790 return processSampledImageType(operands);
791 case spirv::Opcode::OpTypeRuntimeArray:
792 return processRuntimeArrayType(operands);
793 case spirv::Opcode::OpTypeStruct:
794 return processStructType(operands);
795 case spirv::Opcode::OpTypeMatrix:
796 return processMatrixType(operands);
798 return emitError(unknownLoc,
"unhandled type instruction");
805 if (operands.size() != 3)
806 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
808 auto pointeeType = getType(operands[2]);
810 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
813 uint32_t typePointerID = operands[0];
814 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
817 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
818 deferredStructIt != std::end(deferredStructTypesInfos);) {
819 for (
auto *unresolvedMemberIt =
820 std::begin(deferredStructIt->unresolvedMemberTypes);
821 unresolvedMemberIt !=
822 std::end(deferredStructIt->unresolvedMemberTypes);) {
823 if (unresolvedMemberIt->first == typePointerID) {
827 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
828 typeMap[typePointerID];
830 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
832 ++unresolvedMemberIt;
836 if (deferredStructIt->unresolvedMemberTypes.empty()) {
838 auto structType = deferredStructIt->deferredStructType;
840 assert(structType &&
"expected a spirv::StructType");
841 assert(structType.isIdentified() &&
"expected an indentified struct");
843 if (
failed(structType.trySetBody(
844 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
845 deferredStructIt->memberDecorationsInfo)))
848 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
859 if (operands.size() != 3) {
861 "OpTypeArray must have element type and count parameters");
864 Type elementTy = getType(operands[1]);
866 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
872 auto countInfo = getConstant(operands[2]);
874 return emitError(unknownLoc,
"OpTypeArray count <id> ")
875 << operands[2] <<
"can only come from normal constant right now";
878 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
879 count = intVal.getValue().getZExtValue();
881 return emitError(unknownLoc,
"OpTypeArray count must come from a "
882 "scalar integer constant instruction");
886 elementTy, count, typeDecorations.lookup(operands[0]));
892 assert(!operands.empty() &&
"No operands for processing function type");
893 if (operands.size() == 1) {
894 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
896 auto returnType = getType(operands[1]);
898 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
901 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
902 auto ty = getType(operands[i]);
904 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
906 argTypes.push_back(ty);
909 if (!isVoidType(returnType)) {
916 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
918 if (operands.size() != 6) {
920 "OpTypeCooperativeMatrixKHR must have element type, "
921 "scope, row and column parameters, and use");
924 Type elementTy = getType(operands[1]);
927 "OpTypeCooperativeMatrixKHR references undefined <id> ")
931 std::optional<spirv::Scope> scope =
932 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
936 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
940 unsigned rows = getConstantInt(operands[3]).getInt();
941 unsigned columns = getConstantInt(operands[4]).getInt();
943 std::optional<spirv::CooperativeMatrixUseKHR> use =
944 spirv::symbolizeCooperativeMatrixUseKHR(
945 getConstantInt(operands[5]).getInt());
949 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
953 typeMap[operands[0]] =
958 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeNV(
960 if (operands.size() != 5) {
961 return emitError(unknownLoc,
"OpTypeCooperativeMatrixNV must have element "
962 "type and row x column parameters");
965 Type elementTy = getType(operands[1]);
968 "OpTypeCooperativeMatrixNV references undefined <id> ")
972 std::optional<spirv::Scope> scope =
973 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
977 "OpTypeCooperativeMatrixNV references undefined scope <id> ")
981 unsigned rows = getConstantInt(operands[3]).getInt();
982 unsigned columns = getConstantInt(operands[4]).getInt();
984 typeMap[operands[0]] =
991 if (operands.size() != 6) {
992 return emitError(unknownLoc,
"OpTypeJointMatrix must have element "
993 "type and row x column parameters");
996 Type elementTy = getType(operands[1]);
998 return emitError(unknownLoc,
"OpTypeJointMatrix references undefined <id> ")
1002 auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
1005 "OpTypeJointMatrix references undefined scope <id> ")
1009 spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
1010 if (!matrixLayout) {
1012 "OpTypeJointMatrix references undefined scope <id> ")
1015 unsigned rows = getConstantInt(operands[2]).getInt();
1016 unsigned columns = getConstantInt(operands[3]).getInt();
1019 elementTy, scope.value(), rows, columns, matrixLayout.value());
1025 if (operands.size() != 2) {
1026 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1028 Type memberType = getType(operands[1]);
1031 "OpTypeRuntimeArray references undefined <id> ")
1035 memberType, typeDecorations.lookup(operands[0]));
1043 if (operands.empty()) {
1044 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1047 if (operands.size() == 1) {
1049 typeMap[operands[0]] =
1058 for (
auto op : llvm::drop_begin(operands, 1)) {
1059 Type memberType = getType(op);
1060 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1062 if (!memberType && !typeForwardPtr)
1063 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1067 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1069 memberTypes.push_back(memberType);
1074 if (memberDecorationMap.count(operands[0])) {
1075 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1076 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1077 if (allMemberDecorations.count(memberIndex)) {
1078 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1080 if (memberDecoration.first == spirv::Decoration::Offset) {
1082 if (offsetInfo.empty()) {
1083 offsetInfo.resize(memberTypes.size());
1085 offsetInfo[memberIndex] = memberDecoration.second[0];
1087 if (!memberDecoration.second.empty()) {
1088 memberDecorationsInfo.emplace_back(memberIndex, 1,
1089 memberDecoration.first,
1090 memberDecoration.second[0]);
1092 memberDecorationsInfo.emplace_back(memberIndex, 0,
1093 memberDecoration.first, 0);
1101 uint32_t structID = operands[0];
1102 std::string structIdentifier = nameMap.lookup(structID).str();
1104 if (structIdentifier.empty()) {
1105 assert(unresolvedMemberTypes.empty() &&
1106 "didn't expect unresolved member types");
1111 typeMap[structID] = structTy;
1113 if (!unresolvedMemberTypes.empty())
1114 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1115 memberTypes, offsetInfo,
1116 memberDecorationsInfo});
1117 else if (
failed(structTy.trySetBody(memberTypes, offsetInfo,
1118 memberDecorationsInfo)))
1129 if (operands.size() != 3) {
1131 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1132 " (result_id, column_type, and column_count)");
1135 Type elementTy = getType(operands[1]);
1138 "OpTypeMatrix references undefined column type.")
1142 uint32_t colsCount = operands[2];
1149 if (operands.size() != 2)
1151 "OpTypeForwardPointer instruction must have two operands");
1153 typeForwardPointerIDs.insert(operands[0]);
1163 if (operands.size() != 8)
1166 "OpTypeImage with non-eight operands are not supported yet");
1168 Type elementTy = getType(operands[1]);
1170 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1173 auto dim = spirv::symbolizeDim(operands[2]);
1175 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1178 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1180 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1183 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1185 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1188 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1190 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1192 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1193 if (!samplerUseInfo)
1194 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1197 auto format = spirv::symbolizeImageFormat(operands[7]);
1199 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1203 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1204 samplingInfo.value(), samplerUseInfo.value(), format.value());
1210 if (operands.size() != 2)
1211 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1213 Type elementTy = getType(operands[1]);
1216 "OpTypeSampledImage references undefined <id>: ")
1229 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1231 if (operands.size() < 2) {
1233 << opname <<
" must have type <id> and result <id>";
1235 if (operands.size() < 3) {
1237 << opname <<
" must have at least 1 more parameter";
1240 Type resultType = getType(operands[0]);
1242 return emitError(unknownLoc,
"undefined result type from <id> ")
1246 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) ->
LogicalResult {
1247 if (bitwidth == 64) {
1248 if (operands.size() == 4) {
1252 << opname <<
" should have 2 parameters for 64-bit values";
1254 if (bitwidth <= 32) {
1255 if (operands.size() == 3) {
1261 <<
" should have 1 parameter for values with no more than 32 bits";
1263 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1267 auto resultID = operands[1];
1269 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1270 auto bitwidth = intType.getWidth();
1271 if (
failed(checkOperandSizeForBitwidth(bitwidth))) {
1276 if (bitwidth == 64) {
1283 } words = {operands[2], operands[3]};
1284 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1285 }
else if (bitwidth <= 32) {
1286 value = APInt(bitwidth, operands[2],
true);
1289 auto attr = opBuilder.getIntegerAttr(intType, value);
1292 createSpecConstant(unknownLoc, resultID, attr);
1296 constantMap.try_emplace(resultID, attr, intType);
1302 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1303 auto bitwidth = floatType.getWidth();
1304 if (
failed(checkOperandSizeForBitwidth(bitwidth))) {
1309 if (floatType.isF64()) {
1316 } words = {operands[2], operands[3]};
1317 value = APFloat(llvm::bit_cast<double>(words));
1318 }
else if (floatType.isF32()) {
1319 value = APFloat(llvm::bit_cast<float>(operands[2]));
1320 }
else if (floatType.isF16()) {
1321 APInt data(16, operands[2]);
1322 value = APFloat(APFloat::IEEEhalf(), data);
1325 auto attr = opBuilder.getFloatAttr(floatType, value);
1327 createSpecConstant(unknownLoc, resultID, attr);
1331 constantMap.try_emplace(resultID, attr, floatType);
1337 return emitError(unknownLoc,
"OpConstant can only generate values of "
1338 "scalar integer or floating-point type");
1343 if (operands.size() != 2) {
1345 << (isSpec ?
"Spec" :
"") <<
"Constant"
1346 << (isTrue ?
"True" :
"False")
1347 <<
" must have type <id> and result <id>";
1350 auto attr = opBuilder.getBoolAttr(isTrue);
1351 auto resultID = operands[1];
1353 createSpecConstant(unknownLoc, resultID, attr);
1357 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1365 if (operands.size() < 2) {
1367 "OpConstantComposite must have type <id> and result <id>");
1369 if (operands.size() < 3) {
1371 "OpConstantComposite must have at least 1 parameter");
1374 Type resultType = getType(operands[0]);
1376 return emitError(unknownLoc,
"undefined result type from <id> ")
1381 elements.reserve(operands.size() - 2);
1382 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1383 auto elementInfo = getConstant(operands[i]);
1385 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1386 << operands[i] <<
" must come from a normal constant";
1388 elements.push_back(elementInfo->first);
1391 auto resultID = operands[1];
1392 if (
auto vectorType = dyn_cast<VectorType>(resultType)) {
1396 constantMap.try_emplace(resultID, attr, resultType);
1397 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1398 auto attr = opBuilder.getArrayAttr(elements);
1399 constantMap.try_emplace(resultID, attr, resultType);
1401 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1410 if (operands.size() < 2) {
1412 "OpConstantComposite must have type <id> and result <id>");
1414 if (operands.size() < 3) {
1416 "OpConstantComposite must have at least 1 parameter");
1419 Type resultType = getType(operands[0]);
1421 return emitError(unknownLoc,
"undefined result type from <id> ")
1425 auto resultID = operands[1];
1426 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1429 elements.reserve(operands.size() - 2);
1430 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1431 auto elementInfo = getSpecConstant(operands[i]);
1435 auto op = opBuilder.
create<spirv::SpecConstantCompositeOp>(
1437 opBuilder.getArrayAttr(elements));
1438 specConstCompositeMap[resultID] = op;
1445 if (operands.size() < 3)
1446 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1447 "result <id>, and operand opcode");
1449 uint32_t resultTypeID = operands[0];
1451 if (!getType(resultTypeID))
1452 return emitError(unknownLoc,
"undefined result type from <id> ")
1455 uint32_t resultID = operands[1];
1456 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1457 auto emplaceResult = specConstOperationMap.try_emplace(
1459 SpecConstOperationMaterializationInfo{
1460 enclosedOpcode, resultTypeID,
1463 if (!emplaceResult.second)
1464 return emitError(unknownLoc,
"value with <id>: ")
1465 << resultID <<
" is probably defined before.";
1470 Value spirv::Deserializer::materializeSpecConstantOperation(
1471 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1474 Type resultType = getType(resultTypeID);
1487 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1488 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
1491 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1492 enclosedOpResultTypeAndOperands.push_back(fakeID);
1493 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1494 enclosedOpOperands.end());
1501 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1508 auto loc = createFileLineColLoc(opBuilder);
1509 auto specConstOperationOp =
1510 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1512 Region &body = specConstOperationOp.getBody();
1514 body.
getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1521 opBuilder.setInsertionPointToEnd(&block);
1523 opBuilder.create<spirv::YieldOp>(loc, block.
front().
getResult(0));
1524 return specConstOperationOp.getResult();
1529 if (operands.size() != 2) {
1531 "OpConstantNull must have type <id> and result <id>");
1534 Type resultType = getType(operands[0]);
1536 return emitError(unknownLoc,
"undefined result type from <id> ")
1540 auto resultID = operands[1];
1541 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
1542 auto attr = opBuilder.getZeroAttr(resultType);
1545 constantMap.try_emplace(resultID, attr, resultType);
1549 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
1557 Block *spirv::Deserializer::getOrCreateBlock(uint32_t
id) {
1558 if (
auto *block = getBlock(
id)) {
1559 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
1560 <<
" @ " << block <<
"\n");
1567 auto *block = curFunction->addBlock();
1568 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
1569 <<
" @ " << block <<
"\n");
1570 return blockMap[id] = block;
1575 return emitError(unknownLoc,
"OpBranch must appear inside a block");
1578 if (operands.size() != 1) {
1579 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
1582 auto *target = getOrCreateBlock(operands[0]);
1583 auto loc = createFileLineColLoc(opBuilder);
1587 opBuilder.create<spirv::BranchOp>(loc, target);
1597 "OpBranchConditional must appear inside a block");
1600 if (operands.size() != 3 && operands.size() != 5) {
1602 "OpBranchConditional must have condition, true label, "
1603 "false label, and optionally two branch weights");
1606 auto condition = getValue(operands[0]);
1607 auto *trueBlock = getOrCreateBlock(operands[1]);
1608 auto *falseBlock = getOrCreateBlock(operands[2]);
1610 std::optional<std::pair<uint32_t, uint32_t>> weights;
1611 if (operands.size() == 5) {
1612 weights = std::make_pair(operands[3], operands[4]);
1617 auto loc = createFileLineColLoc(opBuilder);
1618 opBuilder.create<spirv::BranchConditionalOp>(
1619 loc, condition, trueBlock,
1629 return emitError(unknownLoc,
"OpLabel must appear inside a function");
1632 if (operands.size() != 1) {
1633 return emitError(unknownLoc,
"OpLabel should only have result <id>");
1636 auto labelID = operands[0];
1638 auto *block = getOrCreateBlock(labelID);
1639 LLVM_DEBUG(logger.startLine()
1640 <<
"[block] populating block " << block <<
"\n");
1642 assert(block->
empty() &&
"re-deserialize the same block!");
1644 opBuilder.setInsertionPointToStart(block);
1645 blockMap[labelID] = curBlock = block;
1653 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
1656 if (operands.size() < 2) {
1659 "OpSelectionMerge must specify merge target and selection control");
1662 auto *mergeBlock = getOrCreateBlock(operands[0]);
1663 auto loc = createFileLineColLoc(opBuilder);
1664 auto selectionControl = operands[1];
1666 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1670 "a block cannot have more than one OpSelectionMerge instruction");
1679 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
1682 if (operands.size() < 3) {
1683 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
1684 "continue target and loop control");
1687 auto *mergeBlock = getOrCreateBlock(operands[0]);
1688 auto *continueBlock = getOrCreateBlock(operands[1]);
1689 auto loc = createFileLineColLoc(opBuilder);
1690 uint32_t loopControl = operands[2];
1693 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1697 "a block cannot have more than one OpLoopMerge instruction");
1705 return emitError(unknownLoc,
"OpPhi must appear in a block");
1708 if (operands.size() < 4) {
1709 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
1710 "and variable-parent pairs");
1714 Type blockArgType = getType(operands[0]);
1715 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1716 valueMap[operands[1]] = blockArg;
1717 LLVM_DEBUG(logger.startLine()
1718 <<
"[phi] created block argument " << blockArg
1719 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
1723 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
1724 uint32_t value = operands[i];
1725 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1726 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1727 blockPhiInfo[predecessorTargetPair].
push_back(value);
1728 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
1729 <<
" with arg id = " << value <<
"\n");
1738 class ControlFlowStructurizer {
1741 ControlFlowStructurizer(
Location loc, uint32_t control,
1744 llvm::ScopedPrinter &logger)
1745 : location(loc), control(control), blockMergeInfo(mergeInfo),
1746 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1749 ControlFlowStructurizer(
Location loc, uint32_t control,
1752 : location(loc), control(control), blockMergeInfo(mergeInfo),
1753 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1768 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1771 spirv::LoopOp createLoopOp(uint32_t loopControl);
1774 void collectBlocksInConstruct();
1783 Block *continueBlock;
1789 llvm::ScopedPrinter &logger;
1795 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1798 OpBuilder builder(&mergeBlock->front());
1800 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
1801 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1802 selectionOp.addMergeBlock();
1807 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1810 OpBuilder builder(&mergeBlock->front());
1812 auto control =
static_cast<spirv::LoopControl
>(loopControl);
1813 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1814 loopOp.addEntryAndMergeBlock();
1819 void ControlFlowStructurizer::collectBlocksInConstruct() {
1820 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
1823 constructBlocks.insert(headerBlock);
1827 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
1828 for (
auto *successor : constructBlocks[i]->getSuccessors())
1829 if (successor != mergeBlock)
1830 constructBlocks.insert(successor);
1836 bool isLoop = continueBlock !=
nullptr;
1838 if (
auto loopOp = createLoopOp(control))
1839 op = loopOp.getOperation();
1841 if (
auto selectionOp = createSelectionOp(control))
1842 op = selectionOp.getOperation();
1851 mapper.
map(mergeBlock, &body.
back());
1853 collectBlocksInConstruct();
1875 for (
auto *block : constructBlocks) {
1878 auto *newBlock = builder.createBlock(&body.
back());
1879 mapper.
map(block, newBlock);
1880 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
1881 <<
" from block " << block <<
"\n");
1885 newBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
1886 mapper.
map(blockArg, newArg);
1887 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
1888 << blockArg <<
" to " << newArg <<
"\n");
1891 LLVM_DEBUG(logger.startLine()
1892 <<
"[cf] block " << block <<
" is a function entry block\n");
1895 for (
auto &op : *block)
1896 newBlock->push_back(op.
clone(mapper));
1900 auto remapOperands = [&](
Operation *op) {
1903 operand.set(mappedOp);
1906 succOp.set(mappedOp);
1908 for (
auto &block : body)
1909 block.walk(remapOperands);
1917 headerBlock->replaceAllUsesWith(mergeBlock);
1920 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
1921 headerBlock->getParentOp()->
print(logger.getOStream());
1922 logger.startLine() <<
"\n";
1926 if (!mergeBlock->args_empty()) {
1927 return mergeBlock->getParentOp()->emitError(
1928 "OpPhi in loop merge block unsupported");
1935 mergeBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
1940 if (!headerBlock->args_empty())
1941 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1945 builder.setInsertionPointToEnd(&body.front());
1946 builder.create<spirv::BranchOp>(location, mapper.
lookupOrNull(headerBlock),
1952 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
1955 for (
auto *block : constructBlocks)
1956 block->dropAllReferences();
1963 for (
auto *block : constructBlocks) {
1967 "failed control flow structurization: it has uses outside of the "
1968 "enclosing selection/loop construct");
1972 for (
auto *block : constructBlocks) {
1981 auto it = blockMergeInfo.find(block);
1982 if (it != blockMergeInfo.end()) {
1988 return emitError(loc,
"failed control flow structurization: nested "
1989 "loop header block should be remapped!");
1991 Block *newContinue = it->second.continueBlock;
1995 return emitError(loc,
"failed control flow structurization: nested "
1996 "loop continue block should be remapped!");
1999 Block *newMerge = it->second.mergeBlock;
2001 newMerge = mappedTo;
2005 blockMergeInfo.
erase(it);
2006 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2015 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2016 <<
" to only contain a spirv.Branch op\n");
2020 builder.setInsertionPointToEnd(block);
2021 builder.create<spirv::BranchOp>(location, mergeBlock);
2023 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2028 LLVM_DEBUG(logger.startLine()
2029 <<
"[cf] after structurizing construct with header block "
2030 << headerBlock <<
":\n"
2039 <<
"//----- [phi] start wiring up block arguments -----//\n";
2045 for (
const auto &info : blockPhiInfo) {
2046 Block *block = info.first.first;
2047 Block *target = info.first.second;
2048 const BlockPhiInfo &phiInfo = info.second;
2050 logger.startLine() <<
"[phi] block " << block <<
"\n";
2051 logger.startLine() <<
"[phi] before creating block argument:\n";
2053 logger.startLine() <<
"\n";
2059 opBuilder.setInsertionPoint(op);
2062 blockArgs.reserve(phiInfo.size());
2063 for (uint32_t valueId : phiInfo) {
2064 if (
Value value = getValue(valueId)) {
2065 blockArgs.push_back(value);
2066 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2067 <<
" id = " << valueId <<
"\n");
2069 return emitError(unknownLoc,
"OpPhi references undefined value!");
2073 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2075 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2078 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2079 assert((branchCondOp.getTrueBlock() == target ||
2080 branchCondOp.getFalseBlock() == target) &&
2081 "expected target to be either the true or false target");
2082 if (target == branchCondOp.getTrueTarget())
2083 opBuilder.create<spirv::BranchConditionalOp>(
2084 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2085 branchCondOp.getFalseBlockArguments(),
2086 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2087 branchCondOp.getFalseTarget());
2089 opBuilder.create<spirv::BranchConditionalOp>(
2090 branchCondOp.getLoc(), branchCondOp.getCondition(),
2091 branchCondOp.getTrueBlockArguments(), blockArgs,
2092 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2093 branchCondOp.getFalseBlock());
2095 branchCondOp.erase();
2097 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2101 logger.startLine() <<
"[phi] after creating block argument:\n";
2103 logger.startLine() <<
"\n";
2106 blockPhiInfo.clear();
2111 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2116 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2119 <<
"//----- [cf] start structurizing control flow -----//\n";
2123 while (!blockMergeInfo.empty()) {
2124 Block *headerBlock = blockMergeInfo.
begin()->first;
2125 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2128 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2129 headerBlock->
print(logger.getOStream());
2130 logger.startLine() <<
"\n";
2133 auto *mergeBlock = mergeInfo.mergeBlock;
2134 assert(mergeBlock &&
"merge block cannot be nullptr");
2135 if (!mergeBlock->args_empty())
2136 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2138 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2139 mergeBlock->print(logger.getOStream());
2140 logger.startLine() <<
"\n";
2143 auto *continueBlock = mergeInfo.continueBlock;
2144 LLVM_DEBUG(
if (continueBlock) {
2145 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2146 continueBlock->print(logger.getOStream());
2147 logger.startLine() <<
"\n";
2151 blockMergeInfo.erase(blockMergeInfo.begin());
2152 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2153 blockMergeInfo, headerBlock,
2154 mergeBlock, continueBlock
2160 if (
failed(structurizer.structurize()))
2167 <<
"//--- [cf] completed structurizing control flow ---//\n";
2180 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2181 if (fileName.empty())
2182 fileName =
"<unknown>";
2194 if (operands.size() != 3)
2195 return emitError(unknownLoc,
"OpLine must have 3 operands");
2196 debugLine = DebugLine{operands[0], operands[1], operands[2]};
2200 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2204 if (operands.size() < 2)
2205 return emitError(unknownLoc,
"OpString needs at least 2 operands");
2207 if (!debugInfoMap.lookup(operands[0]).empty())
2209 "duplicate debug string found for result <id> ")
2212 unsigned wordIndex = 1;
2214 if (wordIndex != operands.size())
2216 "unexpected trailing words in OpString instruction");
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
void push_back(Operation *op)
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
MutableArrayRef< BlockOperand > getBlockOperands()
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixNVType get(Type elementType, Scope scope, unsigned rows, unsigned columns)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
Include the generated interface declarations.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.