23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
35 #define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
79 if (failed(sliceInstruction(opcode, operands)))
82 if (failed(processInstruction(opcode, operands)))
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second,
false))) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
103 return std::move(module);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
/// Checks the SPIR-V module header and records the module version.
///
/// Emits an error at `unknownLoc` when the header is too short, the magic
/// number is wrong, or the major/minor version is not supported.
/// NOTE(review): the conditions guarding the first two diagnostics are not
/// visible in this extract.
117 LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
// The version is decoded from header word binary[1]. Shifting left by 8 and
// then right by 24 keeps bits 23..16 (major); shifting left by 16 and then
// right by 24 keeps bits 15..8 (minor). This assumes 32-bit unsigned words;
// TODO confirm the element type of `binary`.
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
// Major version 1 is handled here: the minor version selects the matching
// spirv::Version::V_1_<minor> enumerator through MIN_VERSION_CASE. The two
// diagnostics below report an unsupported minor or major version.
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
141 #undef MIN_VERSION_CASE
143 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
147 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158 if (operands.size() != 1)
159 return emitError(unknownLoc,
"OpCapability must have one parameter");
161 auto cap = spirv::symbolizeCapability(operands[0]);
163 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
165 capabilities.insert(*cap);
173 "OpExtension must have a literal string for the extension name");
176 unsigned wordIndex = 0;
178 if (wordIndex != words.size())
180 "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
183 return emitError(unknownLoc,
"unknown extension: ") << extName;
185 extensions.insert(*ext);
191 if (words.size() < 2) {
193 "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
197 unsigned wordIndex = 1;
199 if (wordIndex != words.size()) {
201 "unexpected trailing words in OpExtInstImport");
/// Attaches the attribute named by spirv::ModuleOp::getVCETripleAttrName().
/// The visible arguments show the collected `extensions` being passed in as
/// an array together with `context`.
/// NOTE(review): the call that sets the attribute and its remaining arguments
/// fall in gaps of this extract.
206 void spirv::Deserializer::attachVCETriple() {
208 spirv::ModuleOp::getVCETripleAttrName(),
210 extensions.getArrayRef(), context));
215 if (operands.size() != 2)
216 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
219 module->getAddressingModelAttrName(),
220 opBuilder.getAttr<spirv::AddressingModelAttr>(
221 static_cast<spirv::AddressingModel
>(operands.front())));
223 (*module)->setAttr(module->getMemoryModelAttrName(),
224 opBuilder.getAttr<spirv::MemoryModelAttr>(
225 static_cast<spirv::MemoryModel
>(operands.back())));
230 template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
234 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
235 if (words.size() != 4) {
236 return emitError(loc,
"OpDecoration with ")
237 << decorationName <<
"needs a cache control integer literal and a "
238 << cacheControlKind <<
" cache control literal";
240 unsigned cacheLevel = words[2];
241 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
242 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
245 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
246 llvm::append_range(attrs, attrList);
247 attrs.push_back(value);
248 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
256 if (words.size() < 2) {
258 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
260 auto decorationName =
261 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
262 if (decorationName.empty()) {
263 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
265 auto symbol = getSymbolDecoration(decorationName);
266 switch (
static_cast<spirv::Decoration
>(words[1])) {
267 case spirv::Decoration::FPFastMathMode:
268 if (words.size() != 3) {
269 return emitError(unknownLoc,
"OpDecorate with ")
270 << decorationName <<
" needs a single integer literal";
272 decorations[words[0]].set(
274 static_cast<FPFastMathMode
>(words[2])));
276 case spirv::Decoration::FPRoundingMode:
277 if (words.size() != 3) {
278 return emitError(unknownLoc,
"OpDecorate with ")
279 << decorationName <<
" needs a single integer literal";
281 decorations[words[0]].set(
283 static_cast<FPRoundingMode
>(words[2])));
285 case spirv::Decoration::DescriptorSet:
286 case spirv::Decoration::Binding:
287 if (words.size() != 3) {
288 return emitError(unknownLoc,
"OpDecorate with ")
289 << decorationName <<
" needs a single integer literal";
291 decorations[words[0]].set(
292 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
294 case spirv::Decoration::BuiltIn:
295 if (words.size() != 3) {
296 return emitError(unknownLoc,
"OpDecorate with ")
297 << decorationName <<
" needs a single integer literal";
299 decorations[words[0]].set(
300 symbol, opBuilder.getStringAttr(
301 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
303 case spirv::Decoration::ArrayStride:
304 if (words.size() != 3) {
305 return emitError(unknownLoc,
"OpDecorate with ")
306 << decorationName <<
" needs a single integer literal";
308 typeDecorations[words[0]] = words[2];
310 case spirv::Decoration::LinkageAttributes: {
311 if (words.size() < 4) {
312 return emitError(unknownLoc,
"OpDecorate with ")
314 <<
" needs at least 1 string and 1 integer literal";
322 unsigned wordIndex = 2;
324 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
325 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
326 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
331 case spirv::Decoration::Aliased:
332 case spirv::Decoration::AliasedPointer:
333 case spirv::Decoration::Block:
334 case spirv::Decoration::BufferBlock:
335 case spirv::Decoration::Flat:
336 case spirv::Decoration::NonReadable:
337 case spirv::Decoration::NonWritable:
338 case spirv::Decoration::NoPerspective:
339 case spirv::Decoration::NoSignedWrap:
340 case spirv::Decoration::NoUnsignedWrap:
341 case spirv::Decoration::RelaxedPrecision:
342 case spirv::Decoration::Restrict:
343 case spirv::Decoration::RestrictPointer:
344 case spirv::Decoration::NoContraction:
345 case spirv::Decoration::Constant:
346 if (words.size() != 2) {
347 return emitError(unknownLoc,
"OpDecoration with ")
348 << decorationName <<
"needs a single target <id>";
354 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
356 case spirv::Decoration::Location:
357 case spirv::Decoration::SpecId:
358 if (words.size() != 3) {
359 return emitError(unknownLoc,
"OpDecoration with ")
360 << decorationName <<
"needs a single integer literal";
362 decorations[words[0]].set(
363 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
365 case spirv::Decoration::CacheControlLoadINTEL: {
367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
374 case spirv::Decoration::CacheControlStoreINTEL: {
376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
384 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
392 if (words.size() < 3) {
394 "OpMemberDecorate must have at least 3 operands");
397 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
398 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
400 " missing offset specification in OpMemberDecorate with "
401 "Offset decoration");
404 if (words.size() > 3) {
405 decorationOperands = words.slice(3);
407 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
412 if (words.size() < 3) {
413 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
415 unsigned wordIndex = 2;
417 if (wordIndex != words.size()) {
419 "unexpected trailing words in OpMemberName instruction");
421 memberNameMap[words[0]][words[1]] = name;
/// Turns the decorations recorded for function argument `argID` into a single
/// spirv::DecorationAttr for that argument.
///
/// At most one decoration may be selected. Finding a second one is an error,
/// and so is finding none of the supported ones when the argument does have
/// decorations.
/// NOTE(review): the parameter list, the body of the "no decorations" branch
/// and the loop over the argument's attributes are not visible in this
/// extract.
425 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
427 if (!decorations.contains(argID)) {
// Holds the one decoration chosen so far; null until a match is found.
432 spirv::DecorationAttr foundDecorationAttr;
// First candidate group: the aliasing-related decorations. Each candidate is
// compared against the attribute name via its decoration symbol.
434 for (
auto decoration :
435 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436 spirv::Decoration::AliasedPointer,
437 spirv::Decoration::RestrictPointer}) {
439 if (decAttr.getName() !=
440 getSymbolDecoration(stringifyDecoration(decoration)))
443 if (foundDecorationAttr)
445 "more than one Aliased/Restrict decorations for "
446 "function argument with result <id> ")
// Second candidate: RelaxedPrecision, which also conflicts with any
// decoration that was already selected.
453 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
454 spirv::Decoration::RelaxedPrecision))) {
459 if (foundDecorationAttr)
460 return emitError(unknownLoc,
"already found a decoration for function "
461 "argument with result <id> ")
465 context, spirv::Decoration::RelaxedPrecision);
// Decorations exist for this argument but none of the supported ones matched.
469 if (!foundDecorationAttr)
470 return emitError(unknownLoc,
"unimplemented decoration support for "
471 "function argument with result <id> ")
475 foundDecorationAttr);
483 return emitError(unknownLoc,
"found function inside function");
487 if (operands.size() != 4) {
488 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
492 return emitError(unknownLoc,
"undefined result type from <id> ")
496 uint32_t fnID = operands[1];
497 if (funcMap.count(fnID)) {
498 return emitError(unknownLoc,
"duplicate function definition/declaration");
501 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
503 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
507 if (!fnType || !isa<FunctionType>(fnType)) {
508 return emitError(unknownLoc,
"unknown function type from <id> ")
511 auto functionType = cast<FunctionType>(fnType);
513 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
514 (functionType.getNumResults() == 1 &&
515 functionType.getResult(0) != resultType)) {
516 return emitError(unknownLoc,
"mismatch in function type ")
517 << functionType <<
" and return type " << resultType <<
" specified";
520 std::string fnName = getFunctionSymbol(fnID);
521 auto funcOp = opBuilder.create<spirv::FuncOp>(
522 unknownLoc, fnName, functionType, fnControl.value());
524 if (decorations.count(fnID)) {
525 for (
auto attr : decorations[fnID].getAttrs()) {
526 funcOp->setAttr(attr.getName(), attr.getValue());
529 curFunction = funcMap[fnID] = funcOp;
530 auto *entryBlock = funcOp.addEntryBlock();
533 <<
"//===-------------------------------------------===//\n";
534 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
535 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
536 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
537 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
542 argAttrs.resize(functionType.getNumInputs());
545 if (functionType.getNumInputs()) {
546 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547 auto argType = functionType.getInput(i);
548 spirv::Opcode opcode = spirv::Opcode::OpNop;
550 if (failed(sliceInstruction(opcode, operands,
551 spirv::Opcode::OpFunctionParameter))) {
554 if (opcode != spirv::Opcode::OpFunctionParameter) {
557 "missing OpFunctionParameter instruction for argument ")
560 if (operands.size() != 2) {
563 "expected result type and result <id> for OpFunctionParameter");
565 auto argDefinedType =
getType(operands[0]);
566 if (!argDefinedType || argDefinedType != argType) {
568 "mismatch in argument type between function type "
570 << functionType <<
" and argument type definition "
571 << argDefinedType <<
" at argument " << i;
573 if (getValue(operands[1])) {
574 return emitError(unknownLoc,
"duplicate definition of result <id> ")
577 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
581 auto argValue = funcOp.getArgument(i);
582 valueMap[operands[1]] = argValue;
586 if (llvm::any_of(argAttrs, [](
Attribute attr) {
587 auto argAttr = cast<DictionaryAttr>(attr);
588 return !argAttr.empty();
595 auto linkageAttr = funcOp.getLinkageAttributes();
596 auto hasImportLinkage =
597 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
598 spirv::LinkageType::Import);
599 if (hasImportLinkage)
606 spirv::Opcode opcode = spirv::Opcode::OpNop;
614 if (failed(sliceInstruction(opcode, instOperands,
615 spirv::Opcode::OpFunctionEnd))) {
618 if (opcode == spirv::Opcode::OpFunctionEnd) {
619 return processFunctionEnd(instOperands);
621 if (opcode != spirv::Opcode::OpLabel) {
622 return emitError(unknownLoc,
"a basic block must start with OpLabel");
624 if (instOperands.size() != 1) {
625 return emitError(unknownLoc,
"OpLabel should only have result <id>");
627 blockMap[instOperands[0]] = entryBlock;
628 if (failed(processLabel(instOperands))) {
634 while (succeeded(sliceInstruction(opcode, instOperands,
635 spirv::Opcode::OpFunctionEnd)) &&
636 opcode != spirv::Opcode::OpFunctionEnd) {
637 if (failed(processInstruction(opcode, instOperands))) {
641 if (opcode != spirv::Opcode::OpFunctionEnd) {
645 return processFunctionEnd(instOperands);
651 if (!operands.empty()) {
652 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
658 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
663 curFunction = std::nullopt;
668 <<
"//===-------------------------------------------===//\n";
/// Looks up `id` in `constantMap` and, when present, returns the stored
/// (attribute, type) pair.
/// NOTE(review): the statement executed when `id` is absent falls in a gap of
/// this extract; presumably it returns std::nullopt — confirm.
673 std::optional<std::pair<Attribute, Type>>
674 spirv::Deserializer::getConstant(uint32_t
id) {
675 auto constIt = constantMap.find(
id);
676 if (constIt == constantMap.end())
678 return constIt->getSecond();
/// Looks up `id` in `specConstOperationMap` and, when present, returns the
/// stored materialization info.
/// NOTE(review): the statement executed when `id` is absent falls in a gap of
/// this extract; presumably it returns std::nullopt — confirm.
681 std::optional<spirv::SpecConstOperationMaterializationInfo>
682 spirv::Deserializer::getSpecConstantOperation(uint32_t
id) {
683 auto constIt = specConstOperationMap.find(
id);
684 if (constIt == specConstOperationMap.end())
686 return constIt->getSecond();
/// Builds the symbol name for the function with result <id> `id`: the name
/// recorded in `nameMap` when it is non-empty, otherwise "spirv_fn_" followed
/// by the decimal id.
/// NOTE(review): the return statement is not visible in this extract.
689 std::string spirv::Deserializer::getFunctionSymbol(uint32_t
id) {
690 auto funcName = nameMap.lookup(
id).str();
691 if (funcName.empty()) {
692 funcName =
"spirv_fn_" + std::to_string(
id);
/// Builds the symbol name for the spec constant with result <id> `id`: the
/// name recorded in `nameMap` when it is non-empty, otherwise
/// "spirv_spec_const_" followed by the decimal id.
/// NOTE(review): the return statement is not visible in this extract.
697 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t
id) {
698 auto constName = nameMap.lookup(
id).str();
699 if (constName.empty()) {
700 constName =
"spirv_spec_const_" + std::to_string(
id);
/// Creates a spirv::SpecConstantOp for `resultID`, copies over any attributes
/// recorded in `decorations` for that id, and registers the op in
/// `specConstMap`.
///
/// The op's symbol name comes from getSpecConstantSymbol(resultID).
/// NOTE(review): in the visible lines the op is created at `unknownLoc`; the
/// `loc` parameter is not used. Confirm whether that is intentional.
/// NOTE(review): the remaining create() arguments and the return statement
/// fall in gaps of this extract.
705 spirv::SpecConstantOp
706 spirv::Deserializer::createSpecConstant(
Location loc, uint32_t resultID,
707 TypedAttr defaultValue) {
708 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
709 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
// Forward every decoration attribute recorded for this result <id> to the op.
711 if (decorations.count(resultID)) {
712 for (
auto attr : decorations[resultID].getAttrs())
713 op->setAttr(attr.getName(), attr.getValue());
715 specConstMap[resultID] = op;
721 unsigned wordIndex = 0;
722 if (operands.size() < 3) {
725 "OpVariable needs at least 3 operands, type, <id> and storage class");
729 auto type =
getType(operands[wordIndex]);
731 return emitError(unknownLoc,
"unknown result type <id> : ")
732 << operands[wordIndex];
734 auto ptrType = dyn_cast<spirv::PointerType>(type);
737 "expected a result type <id> to be a spirv.ptr, found : ")
743 auto variableID = operands[wordIndex];
744 auto variableName = nameMap.lookup(variableID).str();
745 if (variableName.empty()) {
746 variableName =
"spirv_var_" + std::to_string(variableID);
751 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
752 if (ptrType.getStorageClass() != storageClass) {
753 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
754 << type <<
" and that specified in OpVariable instruction : "
755 << stringifyStorageClass(storageClass);
762 if (wordIndex < operands.size()) {
765 if (
auto initOp = getGlobalVariable(operands[wordIndex]))
767 else if (
auto initOp = getSpecConstant(operands[wordIndex]))
769 else if (
auto initOp = getSpecConstantComposite(operands[wordIndex]))
772 return emitError(unknownLoc,
"unknown <id> ")
773 << operands[wordIndex] <<
"used as initializer";
778 if (wordIndex != operands.size()) {
780 "found more operands than expected when deserializing "
781 "OpVariable instruction, only ")
782 << wordIndex <<
" of " << operands.size() <<
" processed";
784 auto loc = createFileLineColLoc(opBuilder);
785 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
786 loc,
TypeAttr::get(type), opBuilder.getStringAttr(variableName),
790 if (decorations.count(variableID)) {
791 for (
auto attr : decorations[variableID].getAttrs())
792 varOp->setAttr(attr.getName(), attr.getValue());
794 globalVariableMap[variableID] = varOp;
/// Returns the attribute of the constant with result <id> `id` as an
/// IntegerAttr. Because dyn_cast is used, the result is a null attribute when
/// the constant's attribute is not an IntegerAttr.
/// NOTE(review): the handling of an `id` that getConstant() does not know
/// falls in a gap of this extract.
798 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t
id) {
799 auto constInfo = getConstant(
id);
803 return dyn_cast<IntegerAttr>(constInfo->first);
807 if (operands.size() < 2) {
808 return emitError(unknownLoc,
"OpName needs at least 2 operands");
810 if (!nameMap.lookup(operands[0]).empty()) {
811 return emitError(unknownLoc,
"duplicate name found for result <id> ")
814 unsigned wordIndex = 1;
816 if (wordIndex != operands.size()) {
818 "unexpected trailing words in OpName instruction");
820 nameMap[operands[0]] = name;
/// Handles a type-defining instruction, dispatching on `opcode`.
///
/// operands[0] is the result <id> of the type. Scalar types handled inline
/// are stored in `typeMap` under that id; composite and opaque types are
/// delegated to dedicated process*Type helpers. Errors are emitted at
/// `unknownLoc`.
/// NOTE(review): the second parameter, the switch header, the break/return
/// statements and the OpTypeInt/OpTypeVector bodies are only partly visible
/// in this extract.
828 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
// Every type instruction needs at least the result <id>.
830 if (operands.empty()) {
831 return emitError(unknownLoc,
"type instruction with opcode ")
832 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
// A result <id> may be given a type only once.
837 if (typeMap.count(operands[0])) {
838 return emitError(unknownLoc,
"duplicate definition for result <id> ")
// OpTypeVoid is recorded as the builder's NoneType.
843 case spirv::Opcode::OpTypeVoid:
844 if (operands.size() != 1)
845 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
846 typeMap[operands[0]] = opBuilder.getNoneType();
// OpTypeBool is recorded as i1.
848 case spirv::Opcode::OpTypeBool:
849 if (operands.size() != 1)
850 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
851 typeMap[operands[0]] = opBuilder.getI1Type();
853 case spirv::Opcode::OpTypeInt: {
854 if (operands.size() != 3)
856 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
866 : IntegerType::SignednessSemantics::Signless;
// OpTypeFloat: operands[1] selects one of f16/f32/f64; any other value is
// reported as an unsupported bitwidth. The case labels are not visible here.
869 case spirv::Opcode::OpTypeFloat: {
870 if (operands.size() != 2)
871 return emitError(unknownLoc,
"OpTypeFloat must have bitwidth parameter");
874 switch (operands[1]) {
876 floatTy = opBuilder.getF16Type();
879 floatTy = opBuilder.getF32Type();
882 floatTy = opBuilder.getF64Type();
885 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
888 typeMap[operands[0]] = floatTy;
890 case spirv::Opcode::OpTypeVector: {
891 if (operands.size() != 3) {
894 "OpTypeVector must have element type and count parameters");
898 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
// The remaining type opcodes forward the full operand list to a helper.
903 case spirv::Opcode::OpTypePointer: {
904 return processOpTypePointer(operands);
906 case spirv::Opcode::OpTypeArray:
907 return processArrayType(operands);
908 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
909 return processCooperativeMatrixTypeKHR(operands);
910 case spirv::Opcode::OpTypeFunction:
911 return processFunctionType(operands);
912 case spirv::Opcode::OpTypeImage:
913 return processImageType(operands);
914 case spirv::Opcode::OpTypeSampledImage:
915 return processSampledImageType(operands);
916 case spirv::Opcode::OpTypeRuntimeArray:
917 return processRuntimeArrayType(operands);
918 case spirv::Opcode::OpTypeStruct:
919 return processStructType(operands);
920 case spirv::Opcode::OpTypeMatrix:
921 return processMatrixType(operands);
923 return emitError(unknownLoc,
"unhandled type instruction");
930 if (operands.size() != 3)
931 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
933 auto pointeeType =
getType(operands[2]);
935 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
938 uint32_t typePointerID = operands[0];
939 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
942 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
943 deferredStructIt != std::end(deferredStructTypesInfos);) {
944 for (
auto *unresolvedMemberIt =
945 std::begin(deferredStructIt->unresolvedMemberTypes);
946 unresolvedMemberIt !=
947 std::end(deferredStructIt->unresolvedMemberTypes);) {
948 if (unresolvedMemberIt->first == typePointerID) {
952 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
953 typeMap[typePointerID];
955 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
957 ++unresolvedMemberIt;
961 if (deferredStructIt->unresolvedMemberTypes.empty()) {
963 auto structType = deferredStructIt->deferredStructType;
965 assert(structType &&
"expected a spirv::StructType");
966 assert(structType.isIdentified() &&
"expected an indentified struct");
968 if (failed(structType.trySetBody(
969 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
970 deferredStructIt->memberDecorationsInfo)))
973 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
984 if (operands.size() != 3) {
986 "OpTypeArray must have element type and count parameters");
991 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
997 auto countInfo = getConstant(operands[2]);
999 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1000 << operands[2] <<
"can only come from normal constant right now";
1003 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1004 count = intVal.getValue().getZExtValue();
1006 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1007 "scalar integer constant instruction");
1011 elementTy, count, typeDecorations.lookup(operands[0]));
1017 assert(!operands.empty() &&
"No operands for processing function type");
1018 if (operands.size() == 1) {
1019 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1021 auto returnType =
getType(operands[1]);
1023 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1026 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1027 auto ty =
getType(operands[i]);
1029 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1031 argTypes.push_back(ty);
1034 if (!isVoidType(returnType)) {
/// Handles OpTypeCooperativeMatrixKHR and records the type in `typeMap` under
/// result <id> operands[0].
///
/// Exactly six operands are required. operands[2..5] are <id>s of integer
/// constants giving scope, rows, columns and use. Each is resolved through
/// getConstantInt(), and an error is emitted when it cannot be resolved or
/// does not symbolize to a known enumerator.
/// NOTE(review): the parameter list, the element-type lookup and the final
/// type construction fall in gaps of this extract.
1041 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1043 if (operands.size() != 6) {
1045 "OpTypeCooperativeMatrixKHR must have element type, "
1046 "scope, row and column parameters, and use");
1052 "OpTypeCooperativeMatrixKHR references undefined <id> ")
// NOTE(review): the IntegerAttr returned by getConstantInt() is dereferenced
// with .getInt() here without a visible null check, unlike the rows/columns/
// use attributes below, which each have an "undefined constant" diagnostic.
// Confirm operands[2] is guaranteed to name an integer constant.
1056 std::optional<spirv::Scope> scope =
1057 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1061 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1065 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1066 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1067 IntegerAttr useAttr = getConstantInt(operands[5]);
1070 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1071 "undefined constant <id> ")
1075 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1076 "references undefined constant <id> ")
1080 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1081 "undefined constant <id> ")
// getInt() results are narrowed to unsigned for the matrix dimensions.
1084 unsigned rows = rowsAttr.getInt();
1085 unsigned columns = columnsAttr.getInt();
1087 std::optional<spirv::CooperativeMatrixUseKHR> use =
1088 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1092 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1096 typeMap[operands[0]] =
1103 if (operands.size() != 2) {
1104 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1109 "OpTypeRuntimeArray references undefined <id> ")
1113 memberType, typeDecorations.lookup(operands[0]));
1121 if (operands.empty()) {
1122 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1125 if (operands.size() == 1) {
1127 typeMap[operands[0]] =
1136 for (
auto op : llvm::drop_begin(operands, 1)) {
1138 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1140 if (!memberType && !typeForwardPtr)
1141 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1145 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1147 memberTypes.push_back(memberType);
1152 if (memberDecorationMap.count(operands[0])) {
1153 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1154 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1155 if (allMemberDecorations.count(memberIndex)) {
1156 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1158 if (memberDecoration.first == spirv::Decoration::Offset) {
1160 if (offsetInfo.empty()) {
1161 offsetInfo.resize(memberTypes.size());
1163 offsetInfo[memberIndex] = memberDecoration.second[0];
1165 if (!memberDecoration.second.empty()) {
1166 memberDecorationsInfo.emplace_back(memberIndex, 1,
1167 memberDecoration.first,
1168 memberDecoration.second[0]);
1170 memberDecorationsInfo.emplace_back(memberIndex, 0,
1171 memberDecoration.first, 0);
1179 uint32_t structID = operands[0];
1180 std::string structIdentifier = nameMap.lookup(structID).str();
1182 if (structIdentifier.empty()) {
1183 assert(unresolvedMemberTypes.empty() &&
1184 "didn't expect unresolved member types");
1189 typeMap[structID] = structTy;
1191 if (!unresolvedMemberTypes.empty())
1192 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1193 memberTypes, offsetInfo,
1194 memberDecorationsInfo});
1195 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1196 memberDecorationsInfo)))
1207 if (operands.size() != 3) {
1209 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1210 " (result_id, column_type, and column_count)");
1216 "OpTypeMatrix references undefined column type.")
1220 uint32_t colsCount = operands[2];
1227 if (operands.size() != 2)
1229 "OpTypeForwardPointer instruction must have two operands");
1231 typeForwardPointerIDs.insert(operands[0]);
1241 if (operands.size() != 8)
1244 "OpTypeImage with non-eight operands are not supported yet");
1248 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1251 auto dim = spirv::symbolizeDim(operands[2]);
1253 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1256 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1258 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1261 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1263 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1266 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1268 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1270 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1271 if (!samplerUseInfo)
1272 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1275 auto format = spirv::symbolizeImageFormat(operands[7]);
1277 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1281 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1282 samplingInfo.value(), samplerUseInfo.value(), format.value());
1288 if (operands.size() != 2)
1289 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1294 "OpTypeSampledImage references undefined <id>: ")
1307 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1309 if (operands.size() < 2) {
1311 << opname <<
" must have type <id> and result <id>";
1313 if (operands.size() < 3) {
1315 << opname <<
" must have at least 1 more parameter";
1320 return emitError(unknownLoc,
"undefined result type from <id> ")
1324 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1325 if (bitwidth == 64) {
1326 if (operands.size() == 4) {
1330 << opname <<
" should have 2 parameters for 64-bit values";
1332 if (bitwidth <= 32) {
1333 if (operands.size() == 3) {
1339 <<
" should have 1 parameter for values with no more than 32 bits";
1341 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1345 auto resultID = operands[1];
1347 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1348 auto bitwidth = intType.getWidth();
1349 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1354 if (bitwidth == 64) {
1361 } words = {operands[2], operands[3]};
1362 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1363 }
else if (bitwidth <= 32) {
1364 value = APInt(bitwidth, operands[2],
true,
1368 auto attr = opBuilder.getIntegerAttr(intType, value);
1371 createSpecConstant(unknownLoc, resultID, attr);
1375 constantMap.try_emplace(resultID, attr, intType);
1381 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1382 auto bitwidth = floatType.getWidth();
1383 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1388 if (floatType.isF64()) {
1395 } words = {operands[2], operands[3]};
1396 value = APFloat(llvm::bit_cast<double>(words));
1397 }
else if (floatType.isF32()) {
1398 value = APFloat(llvm::bit_cast<float>(operands[2]));
1399 }
else if (floatType.isF16()) {
1400 APInt data(16, operands[2]);
1401 value = APFloat(APFloat::IEEEhalf(), data);
1404 auto attr = opBuilder.getFloatAttr(floatType, value);
1406 createSpecConstant(unknownLoc, resultID, attr);
1410 constantMap.try_emplace(resultID, attr, floatType);
1416 return emitError(unknownLoc,
"OpConstant can only generate values of "
1417 "scalar integer or floating-point type");
/// Handles the boolean constant instructions. `isTrue` gives the value and
/// `isSpec` marks the Spec variant; both also shape the diagnostic text.
///
/// Exactly two operands are required (type <id> and result <id>);
/// operands[1] is the result <id>. The visible lines show two ways of
/// recording the value: createSpecConstant(), or a `constantMap` entry typed
/// as i1.
/// NOTE(review): the parameter list and the condition choosing between the
/// two recording paths fall in gaps of this extract; presumably it is
/// `isSpec` — confirm.
1420 LogicalResult spirv::Deserializer::processConstantBool(
1422 if (operands.size() != 2) {
1424 << (isSpec ?
"Spec" :
"") <<
"Constant"
1425 << (isTrue ?
"True" :
"False")
1426 <<
" must have type <id> and result <id>";
1429 auto attr = opBuilder.getBoolAttr(isTrue);
1430 auto resultID = operands[1];
1432 createSpecConstant(unknownLoc, resultID, attr);
// try_emplace leaves an existing entry for resultID untouched.
1436 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1444 if (operands.size() < 2) {
1446 "OpConstantComposite must have type <id> and result <id>");
1448 if (operands.size() < 3) {
1450 "OpConstantComposite must have at least 1 parameter");
1455 return emitError(unknownLoc,
"undefined result type from <id> ")
1460 elements.reserve(operands.size() - 2);
1461 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1462 auto elementInfo = getConstant(operands[i]);
1464 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1465 << operands[i] <<
" must come from a normal constant";
1467 elements.push_back(elementInfo->first);
1470 auto resultID = operands[1];
1471 if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1475 constantMap.try_emplace(resultID, attr, shapedType);
1476 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1477 auto attr = opBuilder.getArrayAttr(elements);
1478 constantMap.try_emplace(resultID, attr, resultType);
1480 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1489 if (operands.size() < 2) {
1491 "OpConstantComposite must have type <id> and result <id>");
1493 if (operands.size() < 3) {
1495 "OpConstantComposite must have at least 1 parameter");
1500 return emitError(unknownLoc,
"undefined result type from <id> ")
1504 auto resultID = operands[1];
1505 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1508 elements.reserve(operands.size() - 2);
1509 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1510 auto elementInfo = getSpecConstant(operands[i]);
1514 auto op = opBuilder.
create<spirv::SpecConstantCompositeOp>(
1516 opBuilder.getArrayAttr(elements));
1517 specConstCompositeMap[resultID] = op;
1524 if (operands.size() < 3)
1525 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1526 "result <id>, and operand opcode");
1528 uint32_t resultTypeID = operands[0];
1531 return emitError(unknownLoc,
"undefined result type from <id> ")
1534 uint32_t resultID = operands[1];
1535 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1536 auto emplaceResult = specConstOperationMap.try_emplace(
1538 SpecConstOperationMaterializationInfo{
1539 enclosedOpcode, resultTypeID,
1542 if (!emplaceResult.second)
1543 return emitError(unknownLoc,
"value with <id>: ")
1544 << resultID <<
" is probably defined before.";
/// Builds a spirv::SpecConstantOperationOp of `resultType` that wraps the
/// instruction identified by `enclosedOpcode`, and returns the op's result.
/// NOTE(review): this listing omits several source lines (including the tail
/// of the parameter list); comments cover only the visible statements.
1549 Value spirv::Deserializer::materializeSpecConstantOperation(
1550 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
// valueMap is temporarily replaced by newValueMap; the guard restores the
// previous map when it goes out of scope.
1566 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
// Placeholder result <id> given to the enclosed instruction.
1567 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
// Operand list for the enclosed instruction: result type <id>, the
// placeholder result <id>, then the caller-provided operands.
1570 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1571 enclosedOpResultTypeAndOperands.push_back(fakeID);
1572 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1573 enclosedOpOperands.end());
// The enclosed instruction goes through the regular instruction dispatcher.
1580 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1587 auto loc = createFileLineColLoc(opBuilder);
1588 auto specConstOperationOp =
1589 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1591 Region &body = specConstOperationOp.getBody();
// Blocks from curBlock's parent region are spliced to the end of the new
// op's body (the spliced range is not visible in this listing).
1593 body.
getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
// Terminate the block with a YieldOp yielding result 0 of the block's first
// operation.
1600 opBuilder.setInsertionPointToEnd(&block);
1602 opBuilder.create<spirv::YieldOp>(loc, block.
front().
getResult(0));
1603 return specConstOperationOp.getResult();
1608 if (operands.size() != 2) {
1610 "OpConstantNull must have type <id> and result <id>");
1615 return emitError(unknownLoc,
"undefined result type from <id> ")
1619 auto resultID = operands[1];
1620 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
1621 auto attr = opBuilder.getZeroAttr(resultType);
1624 constantMap.try_emplace(resultID, attr, resultType);
1628 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
/// Looks up the block registered for `id`; when none exists, appends a new
/// block to curFunction, records it in blockMap under `id`, and returns it.
/// NOTE(review): the statements closing the lookup branch are not visible in
/// this listing (presumably it returns the found block - TODO confirm).
1636 Block *spirv::Deserializer::getOrCreateBlock(uint32_t
id) {
1637 if (
auto *block = getBlock(
id)) {
// NOTE(review): "exiting" in the log text below looks like a typo for
// "existing"; it is a runtime string, so it is left untouched here.
1638 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
1639 <<
" @ " << block <<
"\n");
// No block is known for this id yet: create one in the current function.
1646 auto *block = curFunction->addBlock();
1647 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
1648 <<
" @ " << block <<
"\n");
// Register the new block and return it in one expression.
1649 return blockMap[id] = block;
1654 return emitError(unknownLoc,
"OpBranch must appear inside a block");
1657 if (operands.size() != 1) {
1658 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
1661 auto *target = getOrCreateBlock(operands[0]);
1662 auto loc = createFileLineColLoc(opBuilder);
1666 opBuilder.create<spirv::BranchOp>(loc, target);
1676 "OpBranchConditional must appear inside a block");
1679 if (operands.size() != 3 && operands.size() != 5) {
1681 "OpBranchConditional must have condition, true label, "
1682 "false label, and optionally two branch weights");
1685 auto condition = getValue(operands[0]);
1686 auto *trueBlock = getOrCreateBlock(operands[1]);
1687 auto *falseBlock = getOrCreateBlock(operands[2]);
1689 std::optional<std::pair<uint32_t, uint32_t>> weights;
1690 if (operands.size() == 5) {
1691 weights = std::make_pair(operands[3], operands[4]);
1696 auto loc = createFileLineColLoc(opBuilder);
1697 opBuilder.create<spirv::BranchConditionalOp>(
1698 loc, condition, trueBlock,
1708 return emitError(unknownLoc,
"OpLabel must appear inside a function");
1711 if (operands.size() != 1) {
1712 return emitError(unknownLoc,
"OpLabel should only have result <id>");
1715 auto labelID = operands[0];
1717 auto *block = getOrCreateBlock(labelID);
1718 LLVM_DEBUG(logger.startLine()
1719 <<
"[block] populating block " << block <<
"\n");
1721 assert(block->
empty() &&
"re-deserialize the same block!");
1723 opBuilder.setInsertionPointToStart(block);
1724 blockMap[labelID] = curBlock = block;
1732 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
1735 if (operands.size() < 2) {
1738 "OpSelectionMerge must specify merge target and selection control");
1741 auto *mergeBlock = getOrCreateBlock(operands[0]);
1742 auto loc = createFileLineColLoc(opBuilder);
1743 auto selectionControl = operands[1];
1745 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1749 "a block cannot have more than one OpSelectionMerge instruction");
1758 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
1761 if (operands.size() < 3) {
1762 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
1763 "continue target and loop control");
1766 auto *mergeBlock = getOrCreateBlock(operands[0]);
1767 auto *continueBlock = getOrCreateBlock(operands[1]);
1768 auto loc = createFileLineColLoc(opBuilder);
1769 uint32_t loopControl = operands[2];
1772 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1776 "a block cannot have more than one OpLoopMerge instruction");
1784 return emitError(unknownLoc,
"OpPhi must appear in a block");
1787 if (operands.size() < 4) {
1788 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
1789 "and variable-parent pairs");
1794 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1795 valueMap[operands[1]] = blockArg;
1796 LLVM_DEBUG(logger.startLine()
1797 <<
"[phi] created block argument " << blockArg
1798 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
1802 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
1803 uint32_t value = operands[i];
1804 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1805 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1806 blockPhiInfo[predecessorTargetPair].
push_back(value);
1807 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
1808 <<
" with arg id = " << value <<
"\n");
/// Helper that turns the blocks of one selection/loop construct, described by
/// a header block, a merge block and an optional continue block, into a
/// structured op via structurize().
/// NOTE(review): this listing omits several lines of the class (access
/// specifiers, some parameters and members); only visible members are
/// described.
1817 class ControlFlowStructurizer {
// Constructor variant that additionally takes a logger.
// NOTE(review): presumably selected by a debug-build preprocessor guard that
// is not visible here - TODO confirm.
1820 ControlFlowStructurizer(
Location loc, uint32_t control,
1823 llvm::ScopedPrinter &logger)
1824 : location(loc), control(control), blockMergeInfo(mergeInfo),
1825 headerBlock(header), mergeBlock(merge), continueBlock(cont),
// Constructor variant without a logger; initializes the same members.
1828 ControlFlowStructurizer(
Location loc, uint32_t control,
1831 : location(loc), control(control), blockMergeInfo(mergeInfo),
1832 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
// Entry point: performs the structurization and reports success/failure.
1842 LogicalResult structurize();
// Creates the spirv::SelectionOp for a selection construct.
1847 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
// Creates the spirv::LoopOp for a loop construct.
1850 spirv::LoopOp createLoopOp(uint32_t loopControl);
// Fills constructBlocks starting from headerBlock.
1853 void collectBlocksInConstruct();
// Null for selections; structurize() treats a non-null value as "is a loop".
1862 Block *continueBlock;
// Used by the debug logging statements in structurize().
1868 llvm::ScopedPrinter &logger;
/// Creates a spirv::SelectionOp at `location`, using a builder anchored at
/// the first operation of mergeBlock, and gives it a merge block.
/// NOTE(review): the return-type line and the return statement are not
/// visible in this listing.
1874 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1877 OpBuilder builder(&mergeBlock->front());
// The raw control word is reinterpreted as the SelectionControl enum. The
// local `control` shadows the class member of the same name.
1879 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
1880 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1881 selectionOp.addMergeBlock(builder);
/// Creates a spirv::LoopOp at `location`, using a builder anchored at the
/// first operation of mergeBlock, and gives it entry and merge blocks.
/// NOTE(review): the return statement is not visible in this listing.
1886 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1889 OpBuilder builder(&mergeBlock->front());
// The raw control word is reinterpreted as the LoopControl enum. The local
// `control` shadows the class member of the same name.
1891 auto control =
static_cast<spirv::LoopControl
>(loopControl);
1892 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1893 loopOp.addEntryAndMergeBlock(builder);
/// Populates constructBlocks with headerBlock and every block reachable from
/// it through successor edges, without ever adding mergeBlock.
/// Must be called with constructBlocks empty (asserted).
1898 void ControlFlowStructurizer::collectBlocksInConstruct() {
1899 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
// Seed the traversal with the header.
1902 constructBlocks.insert(headerBlock);
// constructBlocks doubles as the work list: size() is re-read on every
// iteration, so blocks inserted by the inner loop are visited in turn.
// NOTE(review): presumably a set-like container that ignores duplicate
// inserts, otherwise a cyclic CFG would never terminate - TODO confirm.
1906 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
1907 for (
auto *successor : constructBlocks[i]->getSuccessors())
1908 if (successor != mergeBlock)
1909 constructBlocks.insert(successor);
/// Moves the blocks of the construct rooted at headerBlock into the body of a
/// spirv::LoopOp or spirv::SelectionOp by cloning them, then rewires uses and
/// cleans up the original blocks.
/// NOTE(review): this listing omits a large number of source lines; the
/// comments below describe only the statements that are visible.
1913 LogicalResult ControlFlowStructurizer::structurize() {
// A non-null continueBlock marks the construct as a loop.
1915 bool isLoop = continueBlock !=
nullptr;
// Create the structured op; `op` stays unset if creation yields a null op.
1917 if (
auto loopOp = createLoopOp(control))
1918 op = loopOp.getOperation();
1920 if (
auto selectionOp = createSelectionOp(control))
1921 op = selectionOp.getOperation();
// References to mergeBlock inside cloned code map to body's last block.
1930 mapper.
map(mergeBlock, &body.
back());
1932 collectBlocksInConstruct();
// For every collected block, create a counterpart in body ahead of its last
// block and record the old -> new mapping.
1954 for (
auto *block : constructBlocks) {
1957 auto *newBlock = builder.createBlock(&body.
back());
1958 mapper.
map(block, newBlock);
1959 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
1960 <<
" from block " << block <<
"\n");
// Block arguments are re-created on the counterpart and mapped as well.
1964 newBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
1965 mapper.
map(blockArg, newArg);
1966 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
1967 << blockArg <<
" to " << newArg <<
"\n");
1970 LLVM_DEBUG(logger.startLine()
1971 <<
"[cf] block " << block <<
" is a function entry block\n");
// Clone each operation into the counterpart, remapping through `mapper`.
1974 for (
auto &op : *block)
1975 newBlock->push_back(op.
clone(mapper));
// Callback that points an op's operands and successor blocks at their mapped
// counterparts.
1979 auto remapOperands = [&](
Operation *op) {
1982 operand.set(mappedOp);
1985 succOp.set(mappedOp);
1987 for (
auto &block : body)
1988 block.walk(remapOperands);
// Anything that used to reference the header now references the merge block.
1996 headerBlock->replaceAllUsesWith(mergeBlock);
1999 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2000 headerBlock->getParentOp()->
print(logger.getOStream());
2001 logger.startLine() <<
"\n";
// Pre-existing merge-block arguments are rejected; per the message this
// concerns the loop case.
2005 if (!mergeBlock->args_empty()) {
2006 return mergeBlock->getParentOp()->emitError(
2007 "OpPhi in loop merge block unsupported");
2014 mergeBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
// If the header carried arguments, forward the merge block's arguments.
2019 if (!headerBlock->args_empty())
2020 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
// The first block of body branches to the cloned header.
2024 builder.setInsertionPointToEnd(&body.front());
2025 builder.create<spirv::BranchOp>(location, mapper.
lookupOrNull(headerBlock),
// Add a matching argument to body's last block, queue it for yielding, and
// remember the original argument so its uses can be replaced later.
2053 body.back().addArgument(blockArg.
getType(), blockArg.
getLoc());
2054 valuesToYield.push_back(body.back().getArguments().back());
2055 outsideUses.push_back(blockArg);
2060 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
// The original blocks give up every reference they hold.
2063 for (
auto *block : constructBlocks)
2064 block->dropAllReferences();
// Collect op results and still-used block arguments of the original blocks
// into outsideUses (the matching valuesToYield updates are not visible).
2069 for (
Block *block : constructBlocks) {
2074 outsideUses.push_back(result);
2078 if (!arg.use_empty()) {
2080 outsideUses.push_back(arg);
// Both lists are index-aligned: entry i of outsideUses is later replaced by
// result i of the structured op.
2085 assert(valuesToYield.size() == outsideUses.size());
2089 if (!valuesToYield.empty()) {
2090 LLVM_DEBUG(logger.startLine()
2091 <<
"[cf] yielding values from the selection / loop region\n");
// body's last block is expected to contain exactly one spirv::MergeOp; its
// operands become the yielded values.
2094 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2095 Operation *merge = llvm::getSingleElement(mergeOps);
2097 merge->setOperands(valuesToYield);
// A replacement LoopOp/SelectionOp is created in front of mergeBlock's first
// operation (its remaining arguments are not visible here).
2105 builder.setInsertionPoint(&mergeBlock->front());
2110 newOp = builder.
create<spirv::LoopOp>(
2112 static_cast<spirv::LoopControl
>(control));
2114 newOp = builder.
create<spirv::SelectionOp>(
2116 static_cast<spirv::SelectionControl
>(control));
// Redirect every outside use to the structured op's result at the same index.
2126 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2127 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
// mergeBlock ends up without arguments.
2133 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
// Diagnose anything in the original blocks that is still used from outside
// the construct.
2140 for (
auto *block : constructBlocks) {
2143 return op.
emitOpError(
"failed control flow structurization: value has "
2144 "uses outside of the "
2145 "enclosing selection/loop construct");
2147 if (!arg.use_empty())
2148 return emitError(arg.getLoc(),
"failed control flow structurization: "
2149 "block argument has uses outside of the "
2150 "enclosing selection/loop construct");
2154 for (
auto *block : constructBlocks) {
// A block that is itself recorded as a header in blockMergeInfo gets its
// entry re-keyed to the remapped header/merge/continue blocks.
2195 auto it = blockMergeInfo.find(block);
2196 if (it != blockMergeInfo.end()) {
2202 return emitError(loc,
"failed control flow structurization: nested "
2203 "loop header block should be remapped!");
2205 Block *newContinue = it->second.continueBlock;
2209 return emitError(loc,
"failed control flow structurization: nested "
2210 "loop continue block should be remapped!");
2213 Block *newMerge = it->second.mergeBlock;
2215 newMerge = mappedTo;
// NOTE(review): `it` is erased on the next statement, yet the try_emplace
// after it still reads `it->second.control`. Reading through an erased
// iterator is not guaranteed to be valid, and try_emplace receives that
// member by reference while it may rehash the map. Consider copying
// `it->second.control` into a local before the erase.
2219 blockMergeInfo.
erase(it);
2220 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2227 if (block->walk(updateMergeInfo).wasInterrupted())
2235 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2236 <<
" to only contain a spirv.Branch op\n");
// The block receives a branch to mergeBlock as its terminator.
2240 builder.setInsertionPointToEnd(block);
2241 builder.create<spirv::BranchOp>(location, mergeBlock);
2243 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2248 LLVM_DEBUG(logger.startLine()
2249 <<
"[cf] after structurizing construct with header block "
2250 << headerBlock <<
":\n"
/// Replaces the terminators recorded in blockPhiInfo with equivalents that
/// pass the OpPhi-derived values as block arguments to the target block,
/// then clears blockPhiInfo.
/// NOTE(review): this listing omits several source lines (e.g. where `op`
/// and `blockArgs` are declared); comments cover only visible statements.
2256 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2259 <<
"//----- [phi] start wiring up block arguments -----//\n";
// Each entry is keyed by a (block, target) pair and holds the value <id>s to
// pass along that edge.
2265 for (
const auto &info : blockPhiInfo) {
2266 Block *block = info.first.first;
2267 Block *target = info.first.second;
2268 const BlockPhiInfo &phiInfo = info.second;
2270 logger.startLine() <<
"[phi] block " << block <<
"\n";
2271 logger.startLine() <<
"[phi] before creating block argument:\n";
2273 logger.startLine() <<
"\n";
// New terminators are created right before `op` (presumably the block's
// current terminator - TODO confirm).
2279 opBuilder.setInsertionPoint(op);
// Resolve every value <id>; an unresolved <id> is a hard error.
2282 blockArgs.reserve(phiInfo.size());
2283 for (uint32_t valueId : phiInfo) {
2284 if (
Value value = getValue(valueId)) {
2285 blockArgs.push_back(value);
2286 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2287 <<
" id = " << valueId <<
"\n");
2289 return emitError(unknownLoc,
"OpPhi references undefined value!");
// Only spirv::BranchOp and spirv::BranchConditionalOp are handled; anything
// else produces the "unimplemented terminator" diagnostic further down.
2293 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2295 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2298 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2299 assert((branchCondOp.getTrueBlock() == target ||
2300 branchCondOp.getFalseBlock() == target) &&
2301 "expected target to be either the true or false target");
// blockArgs are attached to whichever side matches `target`; the other
// side keeps its existing arguments.
// NOTE(review): if both successors equal `target`, the true-side check wins
// and only that edge appears to receive blockArgs - TODO confirm.
2302 if (target == branchCondOp.getTrueTarget())
2303 opBuilder.create<spirv::BranchConditionalOp>(
2304 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2305 branchCondOp.getFalseBlockArguments(),
2306 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2307 branchCondOp.getFalseTarget());
2309 opBuilder.create<spirv::BranchConditionalOp>(
2310 branchCondOp.getLoc(), branchCondOp.getCondition(),
2311 branchCondOp.getTrueBlockArguments(), blockArgs,
2312 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2313 branchCondOp.getFalseBlock());
// The old conditional branch is removed once its replacement exists.
2315 branchCondOp.erase();
2317 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2321 logger.startLine() <<
"[phi] after creating block argument:\n";
2323 logger.startLine() <<
"\n";
// All recorded phi information has been consumed.
2326 blockPhiInfo.clear();
2331 <<
"//--- [phi] completed wiring up block arguments ---//\n";
/// For selection headers that end in a spirv::BranchConditionalOp, moves the
/// merge information from the header to a new block (`newBlock`) when the
/// header holds more than one operation or is also another construct's merge
/// block.
/// NOTE(review): this listing omits several source lines (e.g. where
/// `blockMergeInfoCopy`, `terminator`, `newBlock` and `builder` are set up);
/// comments cover only the visible statements.
2336 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
// Iteration runs over blockMergeInfoCopy because blockMergeInfo itself is
// erased from / inserted into inside the loop body.
2339 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2341 auto &[block, mergeInfo] = *it;
// Entries with a continue block (loops) get separate handling; the statement
// guarded by this check is not visible (presumably a skip - TODO confirm).
2344 if (mergeInfo.continueBlock)
2353 if (!isa<spirv::BranchConditionalOp>(terminator))
// Becomes true when `block` is the merge block of any recorded construct.
// NOTE(review): the structured binding in the loop below shadows the outer
// `mergeInfo` for the duration of that loop.
2357 bool splitHeaderMergeBlock =
false;
2358 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2359 if (mergeInfo.mergeBlock == block)
2360 splitHeaderMergeBlock =
true;
// Split when the block holds more than just one operation, or when it also
// serves as a merge block.
2367 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2370 builder.create<spirv::BranchOp>(block->
getParent()->
getLoc(), newBlock);
// Re-key the merge info from the old block to the new block. `mergeInfo`
// refers into the copy, so erasing from blockMergeInfo does not affect it.
2374 blockMergeInfo.erase(block);
2375 blockMergeInfo.try_emplace(newBlock, mergeInfo);
/// Drives control-flow structurization: after splitting conditional blocks,
/// repeatedly takes an entry from blockMergeInfo and runs a
/// ControlFlowStructurizer on it until the map is empty.
/// NOTE(review): this listing omits several source lines (mostly returns and
/// debug-macro scaffolding); comments cover only the visible statements.
2382 LogicalResult spirv::Deserializer::structurizeControlFlow() {
// Structurization is optional and controlled by the deserialization options.
2383 if (!
options.enableControlFlowStructurization) {
2387 <<
"//----- [cf] skip structurizing control flow -----//\n";
2395 <<
"//----- [cf] start structurizing control flow -----//\n";
2400 logger.startLine() <<
"[cf] split conditional blocks\n";
2401 logger.startLine() <<
"\n";
2404 if (failed(splitConditionalBlocks())) {
// Each iteration consumes the map's first entry; the structurizer may add or
// re-key entries in blockMergeInfo, so the loop runs until it is empty.
2411 while (!blockMergeInfo.empty()) {
2412 Block *headerBlock = blockMergeInfo.
begin()->first;
// Copied by value because the entry is erased from the map further down.
2413 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2416 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2417 headerBlock->
print(logger.getOStream());
2418 logger.startLine() <<
"\n";
2421 auto *mergeBlock = mergeInfo.mergeBlock;
2422 assert(mergeBlock &&
"merge block cannot be nullptr");
// Loops (entries with a continue block) whose merge block has arguments are
// rejected.
2423 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2424 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2426 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2427 mergeBlock->print(logger.getOStream());
2428 logger.startLine() <<
"\n";
2431 auto *continueBlock = mergeInfo.continueBlock;
2432 LLVM_DEBUG(
if (continueBlock) {
2433 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2434 continueBlock->print(logger.getOStream());
2435 logger.startLine() <<
"\n";
// Remove the entry before handing the map to the structurizer.
2439 blockMergeInfo.erase(blockMergeInfo.begin());
2440 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2441 blockMergeInfo, headerBlock,
2442 mergeBlock, continueBlock
2448 if (failed(structurizer.structurize()))
2455 <<
"//--- [cf] completed structurizing control flow ---//\n";
2468 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2469 if (fileName.empty())
2470 fileName =
"<unknown>";
2482 if (operands.size() != 3)
2483 return emitError(unknownLoc,
"OpLine must have 3 operands");
2484 debugLine = DebugLine{operands[0], operands[1], operands[2]};
/// Forgets the currently recorded debug line by resetting `debugLine` to
/// std::nullopt.
2488 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2492 if (operands.size() < 2)
2493 return emitError(unknownLoc,
"OpString needs at least 2 operands");
2495 if (!debugInfoMap.lookup(operands[0]).empty())
2497 "duplicate debug string found for result <id> ")
2500 unsigned wordIndex = 1;
2502 if (wordIndex != operands.size())
2504 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool mightHaveTerminator()
Check whether this block might have a terminator.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
void push_back(Operation *op)
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< BlockOperand > getBlockOperands()
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Location getLoc()
Return a location for this region.
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.