24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/bit.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/SaveAndRestore.h"
31 #include "llvm/Support/raw_ostream.h"
36 #define DEBUG_TYPE "spirv-deserialization"
45 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion())
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
79 if (
failed(sliceInstruction(opcode, operands)))
82 if (
failed(processInstruction(opcode, operands)))
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
90 if (
failed(processInstruction(deferred.first, deferred.second,
false))) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
103 return std::move(module);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
141 #undef MIN_VERSION_CASE
143 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
147 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158 if (operands.size() != 1)
159 return emitError(unknownLoc,
"OpMemoryModel must have one parameter");
161 auto cap = spirv::symbolizeCapability(operands[0]);
163 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
165 capabilities.insert(*cap);
173 "OpExtension must have a literal string for the extension name");
176 unsigned wordIndex = 0;
178 if (wordIndex != words.size())
180 "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
183 return emitError(unknownLoc,
"unknown extension: ") << extName;
185 extensions.insert(*ext);
191 if (words.size() < 2) {
193 "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
197 unsigned wordIndex = 1;
199 if (wordIndex != words.size()) {
201 "unexpected trailing words in OpExtInstImport");
206 void spirv::Deserializer::attachVCETriple() {
208 spirv::ModuleOp::getVCETripleAttrName(),
210 extensions.getArrayRef(), context));
215 if (operands.size() != 2)
216 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
219 module->getAddressingModelAttrName(),
220 opBuilder.getAttr<spirv::AddressingModelAttr>(
221 static_cast<spirv::AddressingModel
>(operands.front())));
223 (*module)->setAttr(module->getMemoryModelAttrName(),
224 opBuilder.getAttr<spirv::MemoryModelAttr>(
225 static_cast<spirv::MemoryModel
>(operands.back())));
234 if (words.size() < 2) {
236 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
238 auto decorationName =
239 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
240 if (decorationName.empty()) {
241 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
243 auto symbol = getSymbolDecoration(decorationName);
244 switch (
static_cast<spirv::Decoration
>(words[1])) {
245 case spirv::Decoration::FPFastMathMode:
246 if (words.size() != 3) {
247 return emitError(unknownLoc,
"OpDecorate with ")
248 << decorationName <<
" needs a single integer literal";
250 decorations[words[0]].set(
252 static_cast<FPFastMathMode
>(words[2])));
254 case spirv::Decoration::DescriptorSet:
255 case spirv::Decoration::Binding:
256 if (words.size() != 3) {
257 return emitError(unknownLoc,
"OpDecorate with ")
258 << decorationName <<
" needs a single integer literal";
260 decorations[words[0]].set(
261 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
263 case spirv::Decoration::BuiltIn:
264 if (words.size() != 3) {
265 return emitError(unknownLoc,
"OpDecorate with ")
266 << decorationName <<
" needs a single integer literal";
268 decorations[words[0]].set(
269 symbol, opBuilder.getStringAttr(
270 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
272 case spirv::Decoration::ArrayStride:
273 if (words.size() != 3) {
274 return emitError(unknownLoc,
"OpDecorate with ")
275 << decorationName <<
" needs a single integer literal";
277 typeDecorations[words[0]] = words[2];
279 case spirv::Decoration::LinkageAttributes: {
280 if (words.size() < 4) {
281 return emitError(unknownLoc,
"OpDecorate with ")
283 <<
" needs at least 1 string and 1 integer literal";
291 unsigned wordIndex = 2;
293 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
294 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
295 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
297 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
300 case spirv::Decoration::Aliased:
301 case spirv::Decoration::AliasedPointer:
302 case spirv::Decoration::Block:
303 case spirv::Decoration::BufferBlock:
304 case spirv::Decoration::Flat:
305 case spirv::Decoration::NonReadable:
306 case spirv::Decoration::NonWritable:
307 case spirv::Decoration::NoPerspective:
308 case spirv::Decoration::NoSignedWrap:
309 case spirv::Decoration::NoUnsignedWrap:
310 case spirv::Decoration::RelaxedPrecision:
311 case spirv::Decoration::Restrict:
312 case spirv::Decoration::RestrictPointer:
313 case spirv::Decoration::NoContraction:
314 if (words.size() != 2) {
315 return emitError(unknownLoc,
"OpDecoration with ")
316 << decorationName <<
"needs a single target <id>";
322 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
324 case spirv::Decoration::Location:
325 case spirv::Decoration::SpecId:
326 if (words.size() != 3) {
327 return emitError(unknownLoc,
"OpDecoration with ")
328 << decorationName <<
"needs a single integer literal";
330 decorations[words[0]].set(
331 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
334 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
342 if (words.size() < 3) {
344 "OpMemberDecorate must have at least 3 operands");
347 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
348 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
350 " missing offset specification in OpMemberDecorate with "
351 "Offset decoration");
354 if (words.size() > 3) {
355 decorationOperands = words.slice(3);
357 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
362 if (words.size() < 3) {
363 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
365 unsigned wordIndex = 2;
367 if (wordIndex != words.size()) {
369 "unexpected trailing words in OpMemberName instruction");
371 memberNameMap[words[0]][words[1]] = name;
377 if (!decorations.contains(argID)) {
382 spirv::DecorationAttr foundDecorationAttr;
384 for (
auto decoration :
385 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
386 spirv::Decoration::AliasedPointer,
387 spirv::Decoration::RestrictPointer}) {
389 if (decAttr.getName() !=
390 getSymbolDecoration(stringifyDecoration(decoration)))
393 if (foundDecorationAttr)
395 "more than one Aliased/Restrict decorations for "
396 "function argument with result <id> ")
404 if (!foundDecorationAttr)
405 return emitError(unknownLoc,
"unimplemented decoration support for "
406 "function argument with result <id> ")
410 foundDecorationAttr);
418 return emitError(unknownLoc,
"found function inside function");
422 if (operands.size() != 4) {
423 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
425 Type resultType = getType(operands[0]);
427 return emitError(unknownLoc,
"undefined result type from <id> ")
431 uint32_t fnID = operands[1];
432 if (funcMap.count(fnID)) {
433 return emitError(unknownLoc,
"duplicate function definition/declaration");
436 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
438 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
441 Type fnType = getType(operands[3]);
442 if (!fnType || !isa<FunctionType>(fnType)) {
443 return emitError(unknownLoc,
"unknown function type from <id> ")
446 auto functionType = cast<FunctionType>(fnType);
448 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
449 (functionType.getNumResults() == 1 &&
450 functionType.getResult(0) != resultType)) {
451 return emitError(unknownLoc,
"mismatch in function type ")
452 << functionType <<
" and return type " << resultType <<
" specified";
455 std::string fnName = getFunctionSymbol(fnID);
456 auto funcOp = opBuilder.create<spirv::FuncOp>(
457 unknownLoc, fnName, functionType, fnControl.value());
459 if (decorations.count(fnID)) {
460 for (
auto attr : decorations[fnID].getAttrs()) {
461 funcOp->setAttr(attr.getName(), attr.getValue());
464 curFunction = funcMap[fnID] = funcOp;
465 auto *entryBlock = funcOp.addEntryBlock();
468 <<
"//===-------------------------------------------===//\n";
469 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
470 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
471 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
472 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
477 argAttrs.resize(functionType.getNumInputs());
480 if (functionType.getNumInputs()) {
481 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
482 auto argType = functionType.getInput(i);
483 spirv::Opcode opcode = spirv::Opcode::OpNop;
485 if (
failed(sliceInstruction(opcode, operands,
486 spirv::Opcode::OpFunctionParameter))) {
489 if (opcode != spirv::Opcode::OpFunctionParameter) {
492 "missing OpFunctionParameter instruction for argument ")
495 if (operands.size() != 2) {
498 "expected result type and result <id> for OpFunctionParameter");
500 auto argDefinedType = getType(operands[0]);
501 if (!argDefinedType || argDefinedType != argType) {
503 "mismatch in argument type between function type "
505 << functionType <<
" and argument type definition "
506 << argDefinedType <<
" at argument " << i;
508 if (getValue(operands[1])) {
509 return emitError(unknownLoc,
"duplicate definition of result <id> ")
512 if (
failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
516 auto argValue = funcOp.getArgument(i);
517 valueMap[operands[1]] = argValue;
521 if (llvm::any_of(argAttrs, [](
Attribute attr) {
522 auto argAttr = cast<DictionaryAttr>(attr);
523 return !argAttr.empty();
530 auto linkageAttr = funcOp.getLinkageAttributes();
531 auto hasImportLinkage =
532 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
533 spirv::LinkageType::Import);
534 if (hasImportLinkage)
541 spirv::Opcode opcode = spirv::Opcode::OpNop;
549 if (
failed(sliceInstruction(opcode, instOperands,
550 spirv::Opcode::OpFunctionEnd))) {
553 if (opcode == spirv::Opcode::OpFunctionEnd) {
554 return processFunctionEnd(instOperands);
556 if (opcode != spirv::Opcode::OpLabel) {
557 return emitError(unknownLoc,
"a basic block must start with OpLabel");
559 if (instOperands.size() != 1) {
560 return emitError(unknownLoc,
"OpLabel should only have result <id>");
562 blockMap[instOperands[0]] = entryBlock;
563 if (
failed(processLabel(instOperands))) {
569 while (
succeeded(sliceInstruction(opcode, instOperands,
570 spirv::Opcode::OpFunctionEnd)) &&
571 opcode != spirv::Opcode::OpFunctionEnd) {
572 if (
failed(processInstruction(opcode, instOperands))) {
576 if (opcode != spirv::Opcode::OpFunctionEnd) {
580 return processFunctionEnd(instOperands);
586 if (!operands.empty()) {
587 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
593 if (
failed(wireUpBlockArgument()) ||
failed(structurizeControlFlow())) {
598 curFunction = std::nullopt;
603 <<
"//===-------------------------------------------===//\n";
608 std::optional<std::pair<Attribute, Type>>
609 spirv::Deserializer::getConstant(uint32_t
id) {
610 auto constIt = constantMap.find(
id);
611 if (constIt == constantMap.end())
613 return constIt->getSecond();
616 std::optional<spirv::SpecConstOperationMaterializationInfo>
617 spirv::Deserializer::getSpecConstantOperation(uint32_t
id) {
618 auto constIt = specConstOperationMap.find(
id);
619 if (constIt == specConstOperationMap.end())
621 return constIt->getSecond();
624 std::string spirv::Deserializer::getFunctionSymbol(uint32_t
id) {
625 auto funcName = nameMap.lookup(
id).str();
626 if (funcName.empty()) {
627 funcName =
"spirv_fn_" + std::to_string(
id);
632 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t
id) {
633 auto constName = nameMap.lookup(
id).str();
634 if (constName.empty()) {
635 constName =
"spirv_spec_const_" + std::to_string(
id);
640 spirv::SpecConstantOp
641 spirv::Deserializer::createSpecConstant(
Location loc, uint32_t resultID,
642 TypedAttr defaultValue) {
643 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
644 auto op = opBuilder.
create<spirv::SpecConstantOp>(unknownLoc, symName,
646 if (decorations.count(resultID)) {
647 for (
auto attr : decorations[resultID].getAttrs())
648 op->
setAttr(attr.getName(), attr.getValue());
650 specConstMap[resultID] = op;
656 unsigned wordIndex = 0;
657 if (operands.size() < 3) {
660 "OpVariable needs at least 3 operands, type, <id> and storage class");
664 auto type = getType(operands[wordIndex]);
666 return emitError(unknownLoc,
"unknown result type <id> : ")
667 << operands[wordIndex];
669 auto ptrType = dyn_cast<spirv::PointerType>(type);
672 "expected a result type <id> to be a spirv.ptr, found : ")
678 auto variableID = operands[wordIndex];
679 auto variableName = nameMap.lookup(variableID).str();
680 if (variableName.empty()) {
681 variableName =
"spirv_var_" + std::to_string(variableID);
686 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
687 if (ptrType.getStorageClass() != storageClass) {
688 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
689 << type <<
" and that specified in OpVariable instruction : "
690 << stringifyStorageClass(storageClass);
697 if (wordIndex < operands.size()) {
700 if (
auto initOp = getGlobalVariable(operands[wordIndex]))
702 else if (
auto initOp = getSpecConstant(operands[wordIndex]))
704 else if (
auto initOp = getSpecConstantComposite(operands[wordIndex]))
707 return emitError(unknownLoc,
"unknown <id> ")
708 << operands[wordIndex] <<
"used as initializer";
713 if (wordIndex != operands.size()) {
715 "found more operands than expected when deserializing "
716 "OpVariable instruction, only ")
717 << wordIndex <<
" of " << operands.size() <<
" processed";
719 auto loc = createFileLineColLoc(opBuilder);
720 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
721 loc,
TypeAttr::get(type), opBuilder.getStringAttr(variableName),
725 if (decorations.count(variableID)) {
726 for (
auto attr : decorations[variableID].getAttrs())
727 varOp->setAttr(attr.getName(), attr.getValue());
729 globalVariableMap[variableID] = varOp;
733 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t
id) {
734 auto constInfo = getConstant(
id);
738 return dyn_cast<IntegerAttr>(constInfo->first);
742 if (operands.size() < 2) {
743 return emitError(unknownLoc,
"OpName needs at least 2 operands");
745 if (!nameMap.lookup(operands[0]).empty()) {
746 return emitError(unknownLoc,
"duplicate name found for result <id> ")
749 unsigned wordIndex = 1;
751 if (wordIndex != operands.size()) {
753 "unexpected trailing words in OpName instruction");
755 nameMap[operands[0]] = name;
763 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
765 if (operands.empty()) {
766 return emitError(unknownLoc,
"type instruction with opcode ")
767 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
772 if (typeMap.count(operands[0])) {
773 return emitError(unknownLoc,
"duplicate definition for result <id> ")
778 case spirv::Opcode::OpTypeVoid:
779 if (operands.size() != 1)
780 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
781 typeMap[operands[0]] = opBuilder.getNoneType();
783 case spirv::Opcode::OpTypeBool:
784 if (operands.size() != 1)
785 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
786 typeMap[operands[0]] = opBuilder.getI1Type();
788 case spirv::Opcode::OpTypeInt: {
789 if (operands.size() != 3)
791 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
801 : IntegerType::SignednessSemantics::Signless;
804 case spirv::Opcode::OpTypeFloat: {
805 if (operands.size() != 2)
806 return emitError(unknownLoc,
"OpTypeFloat must have bitwidth parameter");
809 switch (operands[1]) {
811 floatTy = opBuilder.getF16Type();
814 floatTy = opBuilder.getF32Type();
817 floatTy = opBuilder.getF64Type();
820 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
823 typeMap[operands[0]] = floatTy;
825 case spirv::Opcode::OpTypeVector: {
826 if (operands.size() != 3) {
829 "OpTypeVector must have element type and count parameters");
831 Type elementTy = getType(operands[1]);
833 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
838 case spirv::Opcode::OpTypePointer: {
839 return processOpTypePointer(operands);
841 case spirv::Opcode::OpTypeArray:
842 return processArrayType(operands);
843 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
844 return processCooperativeMatrixTypeKHR(operands);
845 case spirv::Opcode::OpTypeFunction:
846 return processFunctionType(operands);
847 case spirv::Opcode::OpTypeJointMatrixINTEL:
848 return processJointMatrixType(operands);
849 case spirv::Opcode::OpTypeImage:
850 return processImageType(operands);
851 case spirv::Opcode::OpTypeSampledImage:
852 return processSampledImageType(operands);
853 case spirv::Opcode::OpTypeRuntimeArray:
854 return processRuntimeArrayType(operands);
855 case spirv::Opcode::OpTypeStruct:
856 return processStructType(operands);
857 case spirv::Opcode::OpTypeMatrix:
858 return processMatrixType(operands);
860 return emitError(unknownLoc,
"unhandled type instruction");
867 if (operands.size() != 3)
868 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
870 auto pointeeType = getType(operands[2]);
872 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
875 uint32_t typePointerID = operands[0];
876 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
879 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
880 deferredStructIt != std::end(deferredStructTypesInfos);) {
881 for (
auto *unresolvedMemberIt =
882 std::begin(deferredStructIt->unresolvedMemberTypes);
883 unresolvedMemberIt !=
884 std::end(deferredStructIt->unresolvedMemberTypes);) {
885 if (unresolvedMemberIt->first == typePointerID) {
889 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
890 typeMap[typePointerID];
892 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
894 ++unresolvedMemberIt;
898 if (deferredStructIt->unresolvedMemberTypes.empty()) {
900 auto structType = deferredStructIt->deferredStructType;
902 assert(structType &&
"expected a spirv::StructType");
903 assert(structType.isIdentified() &&
"expected an indentified struct");
905 if (
failed(structType.trySetBody(
906 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
907 deferredStructIt->memberDecorationsInfo)))
910 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
921 if (operands.size() != 3) {
923 "OpTypeArray must have element type and count parameters");
926 Type elementTy = getType(operands[1]);
928 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
934 auto countInfo = getConstant(operands[2]);
936 return emitError(unknownLoc,
"OpTypeArray count <id> ")
937 << operands[2] <<
"can only come from normal constant right now";
940 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
941 count = intVal.getValue().getZExtValue();
943 return emitError(unknownLoc,
"OpTypeArray count must come from a "
944 "scalar integer constant instruction");
948 elementTy, count, typeDecorations.lookup(operands[0]));
954 assert(!operands.empty() &&
"No operands for processing function type");
955 if (operands.size() == 1) {
956 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
958 auto returnType = getType(operands[1]);
960 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
963 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
964 auto ty = getType(operands[i]);
966 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
968 argTypes.push_back(ty);
971 if (!isVoidType(returnType)) {
978 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
980 if (operands.size() != 6) {
982 "OpTypeCooperativeMatrixKHR must have element type, "
983 "scope, row and column parameters, and use");
986 Type elementTy = getType(operands[1]);
989 "OpTypeCooperativeMatrixKHR references undefined <id> ")
993 std::optional<spirv::Scope> scope =
994 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
998 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1002 unsigned rows = getConstantInt(operands[3]).getInt();
1003 unsigned columns = getConstantInt(operands[4]).getInt();
1005 std::optional<spirv::CooperativeMatrixUseKHR> use =
1006 spirv::symbolizeCooperativeMatrixUseKHR(
1007 getConstantInt(operands[5]).getInt());
1011 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1015 typeMap[operands[0]] =
1022 if (operands.size() != 6) {
1023 return emitError(unknownLoc,
"OpTypeJointMatrix must have element "
1024 "type and row x column parameters");
1027 Type elementTy = getType(operands[1]);
1029 return emitError(unknownLoc,
"OpTypeJointMatrix references undefined <id> ")
1033 auto scope = spirv::symbolizeScope(getConstantInt(operands[5]).getInt());
1036 "OpTypeJointMatrix references undefined scope <id> ")
1040 spirv::symbolizeMatrixLayout(getConstantInt(operands[4]).getInt());
1041 if (!matrixLayout) {
1043 "OpTypeJointMatrix references undefined scope <id> ")
1046 unsigned rows = getConstantInt(operands[2]).getInt();
1047 unsigned columns = getConstantInt(operands[3]).getInt();
1050 elementTy, scope.value(), rows, columns, matrixLayout.value());
1056 if (operands.size() != 2) {
1057 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1059 Type memberType = getType(operands[1]);
1062 "OpTypeRuntimeArray references undefined <id> ")
1066 memberType, typeDecorations.lookup(operands[0]));
1074 if (operands.empty()) {
1075 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1078 if (operands.size() == 1) {
1080 typeMap[operands[0]] =
1089 for (
auto op : llvm::drop_begin(operands, 1)) {
1090 Type memberType = getType(op);
1091 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1093 if (!memberType && !typeForwardPtr)
1094 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1098 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1100 memberTypes.push_back(memberType);
1105 if (memberDecorationMap.count(operands[0])) {
1106 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1107 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1108 if (allMemberDecorations.count(memberIndex)) {
1109 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1111 if (memberDecoration.first == spirv::Decoration::Offset) {
1113 if (offsetInfo.empty()) {
1114 offsetInfo.resize(memberTypes.size());
1116 offsetInfo[memberIndex] = memberDecoration.second[0];
1118 if (!memberDecoration.second.empty()) {
1119 memberDecorationsInfo.emplace_back(memberIndex, 1,
1120 memberDecoration.first,
1121 memberDecoration.second[0]);
1123 memberDecorationsInfo.emplace_back(memberIndex, 0,
1124 memberDecoration.first, 0);
1132 uint32_t structID = operands[0];
1133 std::string structIdentifier = nameMap.lookup(structID).str();
1135 if (structIdentifier.empty()) {
1136 assert(unresolvedMemberTypes.empty() &&
1137 "didn't expect unresolved member types");
1142 typeMap[structID] = structTy;
1144 if (!unresolvedMemberTypes.empty())
1145 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1146 memberTypes, offsetInfo,
1147 memberDecorationsInfo});
1148 else if (
failed(structTy.trySetBody(memberTypes, offsetInfo,
1149 memberDecorationsInfo)))
1160 if (operands.size() != 3) {
1162 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1163 " (result_id, column_type, and column_count)");
1166 Type elementTy = getType(operands[1]);
1169 "OpTypeMatrix references undefined column type.")
1173 uint32_t colsCount = operands[2];
1180 if (operands.size() != 2)
1182 "OpTypeForwardPointer instruction must have two operands");
1184 typeForwardPointerIDs.insert(operands[0]);
1194 if (operands.size() != 8)
1197 "OpTypeImage with non-eight operands are not supported yet");
1199 Type elementTy = getType(operands[1]);
1201 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1204 auto dim = spirv::symbolizeDim(operands[2]);
1206 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1209 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1211 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1214 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1216 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1219 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1221 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1223 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1224 if (!samplerUseInfo)
1225 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1228 auto format = spirv::symbolizeImageFormat(operands[7]);
1230 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1234 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1235 samplingInfo.value(), samplerUseInfo.value(), format.value());
1241 if (operands.size() != 2)
1242 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1244 Type elementTy = getType(operands[1]);
1247 "OpTypeSampledImage references undefined <id>: ")
1260 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1262 if (operands.size() < 2) {
1264 << opname <<
" must have type <id> and result <id>";
1266 if (operands.size() < 3) {
1268 << opname <<
" must have at least 1 more parameter";
1271 Type resultType = getType(operands[0]);
1273 return emitError(unknownLoc,
"undefined result type from <id> ")
1277 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) ->
LogicalResult {
1278 if (bitwidth == 64) {
1279 if (operands.size() == 4) {
1283 << opname <<
" should have 2 parameters for 64-bit values";
1285 if (bitwidth <= 32) {
1286 if (operands.size() == 3) {
1292 <<
" should have 1 parameter for values with no more than 32 bits";
1294 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1298 auto resultID = operands[1];
1300 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1301 auto bitwidth = intType.getWidth();
1302 if (
failed(checkOperandSizeForBitwidth(bitwidth))) {
1307 if (bitwidth == 64) {
1314 } words = {operands[2], operands[3]};
1315 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1316 }
else if (bitwidth <= 32) {
1317 value = APInt(bitwidth, operands[2],
true);
1320 auto attr = opBuilder.getIntegerAttr(intType, value);
1323 createSpecConstant(unknownLoc, resultID, attr);
1327 constantMap.try_emplace(resultID, attr, intType);
1333 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1334 auto bitwidth = floatType.getWidth();
1335 if (
failed(checkOperandSizeForBitwidth(bitwidth))) {
1340 if (floatType.isF64()) {
1347 } words = {operands[2], operands[3]};
1348 value = APFloat(llvm::bit_cast<double>(words));
1349 }
else if (floatType.isF32()) {
1350 value = APFloat(llvm::bit_cast<float>(operands[2]));
1351 }
else if (floatType.isF16()) {
1352 APInt data(16, operands[2]);
1353 value = APFloat(APFloat::IEEEhalf(), data);
1356 auto attr = opBuilder.getFloatAttr(floatType, value);
1358 createSpecConstant(unknownLoc, resultID, attr);
1362 constantMap.try_emplace(resultID, attr, floatType);
1368 return emitError(unknownLoc,
"OpConstant can only generate values of "
1369 "scalar integer or floating-point type");
1374 if (operands.size() != 2) {
1376 << (isSpec ?
"Spec" :
"") <<
"Constant"
1377 << (isTrue ?
"True" :
"False")
1378 <<
" must have type <id> and result <id>";
1381 auto attr = opBuilder.getBoolAttr(isTrue);
1382 auto resultID = operands[1];
1384 createSpecConstant(unknownLoc, resultID, attr);
1388 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1396 if (operands.size() < 2) {
1398 "OpConstantComposite must have type <id> and result <id>");
1400 if (operands.size() < 3) {
1402 "OpConstantComposite must have at least 1 parameter");
1405 Type resultType = getType(operands[0]);
1407 return emitError(unknownLoc,
"undefined result type from <id> ")
1412 elements.reserve(operands.size() - 2);
1413 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1414 auto elementInfo = getConstant(operands[i]);
1416 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1417 << operands[i] <<
" must come from a normal constant";
1419 elements.push_back(elementInfo->first);
1422 auto resultID = operands[1];
1423 if (
auto vectorType = dyn_cast<VectorType>(resultType)) {
1427 constantMap.try_emplace(resultID, attr, resultType);
1428 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1429 auto attr = opBuilder.getArrayAttr(elements);
1430 constantMap.try_emplace(resultID, attr, resultType);
1432 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1441 if (operands.size() < 2) {
1443 "OpConstantComposite must have type <id> and result <id>");
1445 if (operands.size() < 3) {
1447 "OpConstantComposite must have at least 1 parameter");
1450 Type resultType = getType(operands[0]);
1452 return emitError(unknownLoc,
"undefined result type from <id> ")
1456 auto resultID = operands[1];
1457 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1460 elements.reserve(operands.size() - 2);
1461 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1462 auto elementInfo = getSpecConstant(operands[i]);
1466 auto op = opBuilder.
create<spirv::SpecConstantCompositeOp>(
1468 opBuilder.getArrayAttr(elements));
1469 specConstCompositeMap[resultID] = op;
1476 if (operands.size() < 3)
1477 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1478 "result <id>, and operand opcode");
1480 uint32_t resultTypeID = operands[0];
1482 if (!getType(resultTypeID))
1483 return emitError(unknownLoc,
"undefined result type from <id> ")
1486 uint32_t resultID = operands[1];
1487 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1488 auto emplaceResult = specConstOperationMap.try_emplace(
1490 SpecConstOperationMaterializationInfo{
1491 enclosedOpcode, resultTypeID,
1494 if (!emplaceResult.second)
1495 return emitError(unknownLoc,
"value with <id>: ")
1496 << resultID <<
" is probably defined before.";
1501 Value spirv::Deserializer::materializeSpecConstantOperation(
1502 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1505 Type resultType = getType(resultTypeID);
1518 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1519 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
1522 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1523 enclosedOpResultTypeAndOperands.push_back(fakeID);
1524 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1525 enclosedOpOperands.end());
1532 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1539 auto loc = createFileLineColLoc(opBuilder);
1540 auto specConstOperationOp =
1541 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1543 Region &body = specConstOperationOp.getBody();
1545 body.
getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1552 opBuilder.setInsertionPointToEnd(&block);
1554 opBuilder.create<spirv::YieldOp>(loc, block.
front().
getResult(0));
1555 return specConstOperationOp.getResult();
1560 if (operands.size() != 2) {
1562 "OpConstantNull must have type <id> and result <id>");
1565 Type resultType = getType(operands[0]);
1567 return emitError(unknownLoc,
"undefined result type from <id> ")
1571 auto resultID = operands[1];
1572 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
1573 auto attr = opBuilder.getZeroAttr(resultType);
1576 constantMap.try_emplace(resultID, attr, resultType);
1580 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
1588 Block *spirv::Deserializer::getOrCreateBlock(uint32_t
id) {
1589 if (
auto *block = getBlock(
id)) {
1590 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
1591 <<
" @ " << block <<
"\n");
1598 auto *block = curFunction->addBlock();
1599 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
1600 <<
" @ " << block <<
"\n");
1601 return blockMap[id] = block;
1606 return emitError(unknownLoc,
"OpBranch must appear inside a block");
1609 if (operands.size() != 1) {
1610 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
1613 auto *target = getOrCreateBlock(operands[0]);
1614 auto loc = createFileLineColLoc(opBuilder);
1618 opBuilder.create<spirv::BranchOp>(loc, target);
1628 "OpBranchConditional must appear inside a block");
1631 if (operands.size() != 3 && operands.size() != 5) {
1633 "OpBranchConditional must have condition, true label, "
1634 "false label, and optionally two branch weights");
1637 auto condition = getValue(operands[0]);
1638 auto *trueBlock = getOrCreateBlock(operands[1]);
1639 auto *falseBlock = getOrCreateBlock(operands[2]);
1641 std::optional<std::pair<uint32_t, uint32_t>> weights;
1642 if (operands.size() == 5) {
1643 weights = std::make_pair(operands[3], operands[4]);
1648 auto loc = createFileLineColLoc(opBuilder);
1649 opBuilder.create<spirv::BranchConditionalOp>(
1650 loc, condition, trueBlock,
1660 return emitError(unknownLoc,
"OpLabel must appear inside a function");
1663 if (operands.size() != 1) {
1664 return emitError(unknownLoc,
"OpLabel should only have result <id>");
1667 auto labelID = operands[0];
1669 auto *block = getOrCreateBlock(labelID);
1670 LLVM_DEBUG(logger.startLine()
1671 <<
"[block] populating block " << block <<
"\n");
1673 assert(block->
empty() &&
"re-deserialize the same block!");
1675 opBuilder.setInsertionPointToStart(block);
1676 blockMap[labelID] = curBlock = block;
1684 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
1687 if (operands.size() < 2) {
1690 "OpSelectionMerge must specify merge target and selection control");
1693 auto *mergeBlock = getOrCreateBlock(operands[0]);
1694 auto loc = createFileLineColLoc(opBuilder);
1695 auto selectionControl = operands[1];
1697 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1701 "a block cannot have more than one OpSelectionMerge instruction");
1710 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
1713 if (operands.size() < 3) {
1714 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
1715 "continue target and loop control");
1718 auto *mergeBlock = getOrCreateBlock(operands[0]);
1719 auto *continueBlock = getOrCreateBlock(operands[1]);
1720 auto loc = createFileLineColLoc(opBuilder);
1721 uint32_t loopControl = operands[2];
1724 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1728 "a block cannot have more than one OpLoopMerge instruction");
1736 return emitError(unknownLoc,
"OpPhi must appear in a block");
1739 if (operands.size() < 4) {
1740 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
1741 "and variable-parent pairs");
1745 Type blockArgType = getType(operands[0]);
1746 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1747 valueMap[operands[1]] = blockArg;
1748 LLVM_DEBUG(logger.startLine()
1749 <<
"[phi] created block argument " << blockArg
1750 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
1754 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
1755 uint32_t value = operands[i];
1756 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1757 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1758 blockPhiInfo[predecessorTargetPair].
push_back(value);
1759 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
1760 <<
" with arg id = " << value <<
"\n");
1769 class ControlFlowStructurizer {
1772 ControlFlowStructurizer(
Location loc, uint32_t control,
1775 llvm::ScopedPrinter &logger)
1776 : location(loc), control(control), blockMergeInfo(mergeInfo),
1777 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1780 ControlFlowStructurizer(
Location loc, uint32_t control,
1783 : location(loc), control(control), blockMergeInfo(mergeInfo),
1784 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1799 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1802 spirv::LoopOp createLoopOp(uint32_t loopControl);
1805 void collectBlocksInConstruct();
1814 Block *continueBlock;
1820 llvm::ScopedPrinter &logger;
1826 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1829 OpBuilder builder(&mergeBlock->front());
1831 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
1832 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1833 selectionOp.addMergeBlock(builder);
1838 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1841 OpBuilder builder(&mergeBlock->front());
1843 auto control =
static_cast<spirv::LoopControl
>(loopControl);
1844 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1845 loopOp.addEntryAndMergeBlock(builder);
1850 void ControlFlowStructurizer::collectBlocksInConstruct() {
1851 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
1854 constructBlocks.insert(headerBlock);
1858 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
1859 for (
auto *successor : constructBlocks[i]->getSuccessors())
1860 if (successor != mergeBlock)
1861 constructBlocks.insert(successor);
1867 bool isLoop = continueBlock !=
nullptr;
1869 if (
auto loopOp = createLoopOp(control))
1870 op = loopOp.getOperation();
1872 if (
auto selectionOp = createSelectionOp(control))
1873 op = selectionOp.getOperation();
1882 mapper.
map(mergeBlock, &body.
back());
1884 collectBlocksInConstruct();
1906 for (
auto *block : constructBlocks) {
1909 auto *newBlock = builder.createBlock(&body.
back());
1910 mapper.
map(block, newBlock);
1911 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
1912 <<
" from block " << block <<
"\n");
1916 newBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
1917 mapper.
map(blockArg, newArg);
1918 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
1919 << blockArg <<
" to " << newArg <<
"\n");
1922 LLVM_DEBUG(logger.startLine()
1923 <<
"[cf] block " << block <<
" is a function entry block\n");
1926 for (
auto &op : *block)
1927 newBlock->push_back(op.
clone(mapper));
1931 auto remapOperands = [&](
Operation *op) {
1934 operand.set(mappedOp);
1937 succOp.set(mappedOp);
1939 for (
auto &block : body)
1940 block.walk(remapOperands);
1948 headerBlock->replaceAllUsesWith(mergeBlock);
1951 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
1952 headerBlock->getParentOp()->
print(logger.getOStream());
1953 logger.startLine() <<
"\n";
1957 if (!mergeBlock->args_empty()) {
1958 return mergeBlock->getParentOp()->emitError(
1959 "OpPhi in loop merge block unsupported");
1966 mergeBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
1971 if (!headerBlock->args_empty())
1972 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1976 builder.setInsertionPointToEnd(&body.front());
1977 builder.create<spirv::BranchOp>(location, mapper.
lookupOrNull(headerBlock),
1983 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
1986 for (
auto *block : constructBlocks)
1987 block->dropAllReferences();
1994 for (
auto *block : constructBlocks) {
1998 "failed control flow structurization: it has uses outside of the "
1999 "enclosing selection/loop construct");
2003 for (
auto *block : constructBlocks) {
2012 auto it = blockMergeInfo.find(block);
2013 if (it != blockMergeInfo.end()) {
2019 return emitError(loc,
"failed control flow structurization: nested "
2020 "loop header block should be remapped!");
2022 Block *newContinue = it->second.continueBlock;
2026 return emitError(loc,
"failed control flow structurization: nested "
2027 "loop continue block should be remapped!");
2030 Block *newMerge = it->second.mergeBlock;
2032 newMerge = mappedTo;
2036 blockMergeInfo.
erase(it);
2037 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2046 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2047 <<
" to only contain a spirv.Branch op\n");
2051 builder.setInsertionPointToEnd(block);
2052 builder.create<spirv::BranchOp>(location, mergeBlock);
2054 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2059 LLVM_DEBUG(logger.startLine()
2060 <<
"[cf] after structurizing construct with header block "
2061 << headerBlock <<
":\n"
2070 <<
"//----- [phi] start wiring up block arguments -----//\n";
2076 for (
const auto &info : blockPhiInfo) {
2077 Block *block = info.first.first;
2078 Block *target = info.first.second;
2079 const BlockPhiInfo &phiInfo = info.second;
2081 logger.startLine() <<
"[phi] block " << block <<
"\n";
2082 logger.startLine() <<
"[phi] before creating block argument:\n";
2084 logger.startLine() <<
"\n";
2090 opBuilder.setInsertionPoint(op);
2093 blockArgs.reserve(phiInfo.size());
2094 for (uint32_t valueId : phiInfo) {
2095 if (
Value value = getValue(valueId)) {
2096 blockArgs.push_back(value);
2097 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2098 <<
" id = " << valueId <<
"\n");
2100 return emitError(unknownLoc,
"OpPhi references undefined value!");
2104 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2106 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2109 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2110 assert((branchCondOp.getTrueBlock() == target ||
2111 branchCondOp.getFalseBlock() == target) &&
2112 "expected target to be either the true or false target");
2113 if (target == branchCondOp.getTrueTarget())
2114 opBuilder.create<spirv::BranchConditionalOp>(
2115 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2116 branchCondOp.getFalseBlockArguments(),
2117 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2118 branchCondOp.getFalseTarget());
2120 opBuilder.create<spirv::BranchConditionalOp>(
2121 branchCondOp.getLoc(), branchCondOp.getCondition(),
2122 branchCondOp.getTrueBlockArguments(), blockArgs,
2123 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2124 branchCondOp.getFalseBlock());
2126 branchCondOp.erase();
2128 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2132 logger.startLine() <<
"[phi] after creating block argument:\n";
2134 logger.startLine() <<
"\n";
2137 blockPhiInfo.clear();
2142 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2147 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2150 <<
"//----- [cf] start structurizing control flow -----//\n";
2154 while (!blockMergeInfo.empty()) {
2155 Block *headerBlock = blockMergeInfo.
begin()->first;
2156 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2159 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2160 headerBlock->
print(logger.getOStream());
2161 logger.startLine() <<
"\n";
2164 auto *mergeBlock = mergeInfo.mergeBlock;
2165 assert(mergeBlock &&
"merge block cannot be nullptr");
2166 if (!mergeBlock->args_empty())
2167 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2169 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2170 mergeBlock->print(logger.getOStream());
2171 logger.startLine() <<
"\n";
2174 auto *continueBlock = mergeInfo.continueBlock;
2175 LLVM_DEBUG(
if (continueBlock) {
2176 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2177 continueBlock->print(logger.getOStream());
2178 logger.startLine() <<
"\n";
2182 blockMergeInfo.erase(blockMergeInfo.begin());
2183 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2184 blockMergeInfo, headerBlock,
2185 mergeBlock, continueBlock
2191 if (
failed(structurizer.structurize()))
2198 <<
"//--- [cf] completed structurizing control flow ---//\n";
2211 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2212 if (fileName.empty())
2213 fileName =
"<unknown>";
2225 if (operands.size() != 3)
2226 return emitError(unknownLoc,
"OpLine must have 3 operands");
2227 debugLine = DebugLine{operands[0], operands[1], operands[2]};
2231 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2235 if (operands.size() < 2)
2236 return emitError(unknownLoc,
"OpString needs at least 2 operands");
2238 if (!debugInfoMap.lookup(operands[0]).empty())
2240 "duplicate debug string found for result <id> ")
2243 unsigned wordIndex = 1;
2245 if (wordIndex != operands.size())
2247 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
void push_back(Operation *op)
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
MutableArrayRef< BlockOperand > getBlockOperands()
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static JointMatrixINTELType get(Type elementType, Scope scope, unsigned rows, unsigned columns, MatrixLayout matrixLayout)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
Include the generated interface declarations.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.