23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
35 #define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
79 if (failed(sliceInstruction(opcode, operands)))
82 if (failed(processInstruction(opcode, operands)))
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second,
false))) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
103 return std::move(module);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
117 LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
141 #undef MIN_VERSION_CASE
143 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
147 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158 if (operands.size() != 1)
159 return emitError(unknownLoc,
"OpCapability must have one parameter");
161 auto cap = spirv::symbolizeCapability(operands[0]);
163 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
165 capabilities.insert(*cap);
173 "OpExtension must have a literal string for the extension name");
176 unsigned wordIndex = 0;
178 if (wordIndex != words.size())
180 "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
183 return emitError(unknownLoc,
"unknown extension: ") << extName;
185 extensions.insert(*ext);
191 if (words.size() < 2) {
193 "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
197 unsigned wordIndex = 1;
199 if (wordIndex != words.size()) {
201 "unexpected trailing words in OpExtInstImport");
206 void spirv::Deserializer::attachVCETriple() {
208 spirv::ModuleOp::getVCETripleAttrName(),
210 extensions.getArrayRef(), context));
215 if (operands.size() != 2)
216 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
219 module->getAddressingModelAttrName(),
220 opBuilder.getAttr<spirv::AddressingModelAttr>(
221 static_cast<spirv::AddressingModel
>(operands.front())));
223 (*module)->setAttr(module->getMemoryModelAttrName(),
224 opBuilder.getAttr<spirv::MemoryModelAttr>(
225 static_cast<spirv::MemoryModel
>(operands.back())));
230 template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
234 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
235 if (words.size() != 4) {
236 return emitError(loc,
"OpDecoration with ")
237 << decorationName <<
"needs a cache control integer literal and a "
238 << cacheControlKind <<
" cache control literal";
240 unsigned cacheLevel = words[2];
241 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
242 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
245 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
246 llvm::append_range(attrs, attrList);
247 attrs.push_back(value);
248 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
256 if (words.size() < 2) {
258 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
260 auto decorationName =
261 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
262 if (decorationName.empty()) {
263 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
265 auto symbol = getSymbolDecoration(decorationName);
266 switch (
static_cast<spirv::Decoration
>(words[1])) {
267 case spirv::Decoration::FPFastMathMode:
268 if (words.size() != 3) {
269 return emitError(unknownLoc,
"OpDecorate with ")
270 << decorationName <<
" needs a single integer literal";
272 decorations[words[0]].set(
274 static_cast<FPFastMathMode
>(words[2])));
276 case spirv::Decoration::FPRoundingMode:
277 if (words.size() != 3) {
278 return emitError(unknownLoc,
"OpDecorate with ")
279 << decorationName <<
" needs a single integer literal";
281 decorations[words[0]].set(
283 static_cast<FPRoundingMode
>(words[2])));
285 case spirv::Decoration::DescriptorSet:
286 case spirv::Decoration::Binding:
287 if (words.size() != 3) {
288 return emitError(unknownLoc,
"OpDecorate with ")
289 << decorationName <<
" needs a single integer literal";
291 decorations[words[0]].set(
292 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
294 case spirv::Decoration::BuiltIn:
295 if (words.size() != 3) {
296 return emitError(unknownLoc,
"OpDecorate with ")
297 << decorationName <<
" needs a single integer literal";
299 decorations[words[0]].set(
300 symbol, opBuilder.getStringAttr(
301 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
303 case spirv::Decoration::ArrayStride:
304 if (words.size() != 3) {
305 return emitError(unknownLoc,
"OpDecorate with ")
306 << decorationName <<
" needs a single integer literal";
308 typeDecorations[words[0]] = words[2];
310 case spirv::Decoration::LinkageAttributes: {
311 if (words.size() < 4) {
312 return emitError(unknownLoc,
"OpDecorate with ")
314 <<
" needs at least 1 string and 1 integer literal";
322 unsigned wordIndex = 2;
324 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
325 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
326 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
331 case spirv::Decoration::Aliased:
332 case spirv::Decoration::AliasedPointer:
333 case spirv::Decoration::Block:
334 case spirv::Decoration::BufferBlock:
335 case spirv::Decoration::Flat:
336 case spirv::Decoration::NonReadable:
337 case spirv::Decoration::NonWritable:
338 case spirv::Decoration::NoPerspective:
339 case spirv::Decoration::NoSignedWrap:
340 case spirv::Decoration::NoUnsignedWrap:
341 case spirv::Decoration::RelaxedPrecision:
342 case spirv::Decoration::Restrict:
343 case spirv::Decoration::RestrictPointer:
344 case spirv::Decoration::NoContraction:
345 case spirv::Decoration::Constant:
346 if (words.size() != 2) {
347 return emitError(unknownLoc,
"OpDecoration with ")
348 << decorationName <<
"needs a single target <id>";
354 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
356 case spirv::Decoration::Location:
357 case spirv::Decoration::SpecId:
358 if (words.size() != 3) {
359 return emitError(unknownLoc,
"OpDecoration with ")
360 << decorationName <<
"needs a single integer literal";
362 decorations[words[0]].set(
363 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
365 case spirv::Decoration::CacheControlLoadINTEL: {
367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
374 case spirv::Decoration::CacheControlStoreINTEL: {
376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
384 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
392 if (words.size() < 3) {
394 "OpMemberDecorate must have at least 3 operands");
397 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
398 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
400 " missing offset specification in OpMemberDecorate with "
401 "Offset decoration");
404 if (words.size() > 3) {
405 decorationOperands = words.slice(3);
407 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
412 if (words.size() < 3) {
413 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
415 unsigned wordIndex = 2;
417 if (wordIndex != words.size()) {
419 "unexpected trailing words in OpMemberName instruction");
421 memberNameMap[words[0]][words[1]] = name;
425 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
427 if (!decorations.contains(argID)) {
432 spirv::DecorationAttr foundDecorationAttr;
434 for (
auto decoration :
435 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436 spirv::Decoration::AliasedPointer,
437 spirv::Decoration::RestrictPointer}) {
439 if (decAttr.getName() !=
440 getSymbolDecoration(stringifyDecoration(decoration)))
443 if (foundDecorationAttr)
445 "more than one Aliased/Restrict decorations for "
446 "function argument with result <id> ")
453 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
454 spirv::Decoration::RelaxedPrecision))) {
459 if (foundDecorationAttr)
460 return emitError(unknownLoc,
"already found a decoration for function "
461 "argument with result <id> ")
465 context, spirv::Decoration::RelaxedPrecision);
469 if (!foundDecorationAttr)
470 return emitError(unknownLoc,
"unimplemented decoration support for "
471 "function argument with result <id> ")
475 foundDecorationAttr);
483 return emitError(unknownLoc,
"found function inside function");
487 if (operands.size() != 4) {
488 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
492 return emitError(unknownLoc,
"undefined result type from <id> ")
496 uint32_t fnID = operands[1];
497 if (funcMap.count(fnID)) {
498 return emitError(unknownLoc,
"duplicate function definition/declaration");
501 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
503 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
507 if (!fnType || !isa<FunctionType>(fnType)) {
508 return emitError(unknownLoc,
"unknown function type from <id> ")
511 auto functionType = cast<FunctionType>(fnType);
513 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
514 (functionType.getNumResults() == 1 &&
515 functionType.getResult(0) != resultType)) {
516 return emitError(unknownLoc,
"mismatch in function type ")
517 << functionType <<
" and return type " << resultType <<
" specified";
520 std::string fnName = getFunctionSymbol(fnID);
521 auto funcOp = opBuilder.create<spirv::FuncOp>(
522 unknownLoc, fnName, functionType, fnControl.value());
524 if (decorations.count(fnID)) {
525 for (
auto attr : decorations[fnID].getAttrs()) {
526 funcOp->setAttr(attr.getName(), attr.getValue());
529 curFunction = funcMap[fnID] = funcOp;
530 auto *entryBlock = funcOp.addEntryBlock();
533 <<
"//===-------------------------------------------===//\n";
534 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
535 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
536 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
537 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
542 argAttrs.resize(functionType.getNumInputs());
545 if (functionType.getNumInputs()) {
546 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547 auto argType = functionType.getInput(i);
548 spirv::Opcode opcode = spirv::Opcode::OpNop;
550 if (failed(sliceInstruction(opcode, operands,
551 spirv::Opcode::OpFunctionParameter))) {
554 if (opcode != spirv::Opcode::OpFunctionParameter) {
557 "missing OpFunctionParameter instruction for argument ")
560 if (operands.size() != 2) {
563 "expected result type and result <id> for OpFunctionParameter");
565 auto argDefinedType =
getType(operands[0]);
566 if (!argDefinedType || argDefinedType != argType) {
568 "mismatch in argument type between function type "
570 << functionType <<
" and argument type definition "
571 << argDefinedType <<
" at argument " << i;
573 if (getValue(operands[1])) {
574 return emitError(unknownLoc,
"duplicate definition of result <id> ")
577 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
581 auto argValue = funcOp.getArgument(i);
582 valueMap[operands[1]] = argValue;
586 if (llvm::any_of(argAttrs, [](
Attribute attr) {
587 auto argAttr = cast<DictionaryAttr>(attr);
588 return !argAttr.empty();
595 auto linkageAttr = funcOp.getLinkageAttributes();
596 auto hasImportLinkage =
597 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
598 spirv::LinkageType::Import);
599 if (hasImportLinkage)
606 spirv::Opcode opcode = spirv::Opcode::OpNop;
614 if (failed(sliceInstruction(opcode, instOperands,
615 spirv::Opcode::OpFunctionEnd))) {
618 if (opcode == spirv::Opcode::OpFunctionEnd) {
619 return processFunctionEnd(instOperands);
621 if (opcode != spirv::Opcode::OpLabel) {
622 return emitError(unknownLoc,
"a basic block must start with OpLabel");
624 if (instOperands.size() != 1) {
625 return emitError(unknownLoc,
"OpLabel should only have result <id>");
627 blockMap[instOperands[0]] = entryBlock;
628 if (failed(processLabel(instOperands))) {
634 while (succeeded(sliceInstruction(opcode, instOperands,
635 spirv::Opcode::OpFunctionEnd)) &&
636 opcode != spirv::Opcode::OpFunctionEnd) {
637 if (failed(processInstruction(opcode, instOperands))) {
641 if (opcode != spirv::Opcode::OpFunctionEnd) {
645 return processFunctionEnd(instOperands);
651 if (!operands.empty()) {
652 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
658 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
663 curFunction = std::nullopt;
668 <<
"//===-------------------------------------------===//\n";
673 std::optional<std::pair<Attribute, Type>>
674 spirv::Deserializer::getConstant(uint32_t
id) {
675 auto constIt = constantMap.find(
id);
676 if (constIt == constantMap.end())
678 return constIt->getSecond();
681 std::optional<spirv::SpecConstOperationMaterializationInfo>
682 spirv::Deserializer::getSpecConstantOperation(uint32_t
id) {
683 auto constIt = specConstOperationMap.find(
id);
684 if (constIt == specConstOperationMap.end())
686 return constIt->getSecond();
689 std::string spirv::Deserializer::getFunctionSymbol(uint32_t
id) {
690 auto funcName = nameMap.lookup(
id).str();
691 if (funcName.empty()) {
692 funcName =
"spirv_fn_" + std::to_string(
id);
697 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t
id) {
698 auto constName = nameMap.lookup(
id).str();
699 if (constName.empty()) {
700 constName =
"spirv_spec_const_" + std::to_string(
id);
705 spirv::SpecConstantOp
706 spirv::Deserializer::createSpecConstant(
Location loc, uint32_t resultID,
707 TypedAttr defaultValue) {
708 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
709 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
711 if (decorations.count(resultID)) {
712 for (
auto attr : decorations[resultID].getAttrs())
713 op->setAttr(attr.getName(), attr.getValue());
715 specConstMap[resultID] = op;
721 unsigned wordIndex = 0;
722 if (operands.size() < 3) {
725 "OpVariable needs at least 3 operands, type, <id> and storage class");
729 auto type =
getType(operands[wordIndex]);
731 return emitError(unknownLoc,
"unknown result type <id> : ")
732 << operands[wordIndex];
734 auto ptrType = dyn_cast<spirv::PointerType>(type);
737 "expected a result type <id> to be a spirv.ptr, found : ")
743 auto variableID = operands[wordIndex];
744 auto variableName = nameMap.lookup(variableID).str();
745 if (variableName.empty()) {
746 variableName =
"spirv_var_" + std::to_string(variableID);
751 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
752 if (ptrType.getStorageClass() != storageClass) {
753 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
754 << type <<
" and that specified in OpVariable instruction : "
755 << stringifyStorageClass(storageClass);
762 if (wordIndex < operands.size()) {
765 if (
auto initOp = getGlobalVariable(operands[wordIndex]))
767 else if (
auto initOp = getSpecConstant(operands[wordIndex]))
769 else if (
auto initOp = getSpecConstantComposite(operands[wordIndex]))
772 return emitError(unknownLoc,
"unknown <id> ")
773 << operands[wordIndex] <<
"used as initializer";
778 if (wordIndex != operands.size()) {
780 "found more operands than expected when deserializing "
781 "OpVariable instruction, only ")
782 << wordIndex <<
" of " << operands.size() <<
" processed";
784 auto loc = createFileLineColLoc(opBuilder);
785 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
786 loc,
TypeAttr::get(type), opBuilder.getStringAttr(variableName),
790 if (decorations.count(variableID)) {
791 for (
auto attr : decorations[variableID].getAttrs())
792 varOp->setAttr(attr.getName(), attr.getValue());
794 globalVariableMap[variableID] = varOp;
798 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t
id) {
799 auto constInfo = getConstant(
id);
803 return dyn_cast<IntegerAttr>(constInfo->first);
807 if (operands.size() < 2) {
808 return emitError(unknownLoc,
"OpName needs at least 2 operands");
810 if (!nameMap.lookup(operands[0]).empty()) {
811 return emitError(unknownLoc,
"duplicate name found for result <id> ")
814 unsigned wordIndex = 1;
816 if (wordIndex != operands.size()) {
818 "unexpected trailing words in OpName instruction");
820 nameMap[operands[0]] = name;
828 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
830 if (operands.empty()) {
831 return emitError(unknownLoc,
"type instruction with opcode ")
832 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
837 if (typeMap.count(operands[0])) {
838 return emitError(unknownLoc,
"duplicate definition for result <id> ")
843 case spirv::Opcode::OpTypeVoid:
844 if (operands.size() != 1)
845 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
846 typeMap[operands[0]] = opBuilder.getNoneType();
848 case spirv::Opcode::OpTypeBool:
849 if (operands.size() != 1)
850 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
851 typeMap[operands[0]] = opBuilder.getI1Type();
853 case spirv::Opcode::OpTypeInt: {
854 if (operands.size() != 3)
856 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
866 : IntegerType::SignednessSemantics::Signless;
869 case spirv::Opcode::OpTypeFloat: {
870 if (operands.size() != 2 && operands.size() != 3)
872 "OpTypeFloat expects either 2 operands (type, bitwidth) "
873 "or 3 operands (type, bitwidth, encoding), but got ")
875 uint32_t bitWidth = operands[1];
880 floatTy = opBuilder.getF16Type();
883 floatTy = opBuilder.getF32Type();
886 floatTy = opBuilder.getF64Type();
889 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
893 if (operands.size() == 3) {
894 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
895 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
899 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
900 << bitWidth <<
" (expected 16)";
901 floatTy = opBuilder.getBF16Type();
904 typeMap[operands[0]] = floatTy;
906 case spirv::Opcode::OpTypeVector: {
907 if (operands.size() != 3) {
910 "OpTypeVector must have element type and count parameters");
914 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
919 case spirv::Opcode::OpTypePointer: {
920 return processOpTypePointer(operands);
922 case spirv::Opcode::OpTypeArray:
923 return processArrayType(operands);
924 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
925 return processCooperativeMatrixTypeKHR(operands);
926 case spirv::Opcode::OpTypeFunction:
927 return processFunctionType(operands);
928 case spirv::Opcode::OpTypeImage:
929 return processImageType(operands);
930 case spirv::Opcode::OpTypeSampledImage:
931 return processSampledImageType(operands);
932 case spirv::Opcode::OpTypeRuntimeArray:
933 return processRuntimeArrayType(operands);
934 case spirv::Opcode::OpTypeStruct:
935 return processStructType(operands);
936 case spirv::Opcode::OpTypeMatrix:
937 return processMatrixType(operands);
939 return emitError(unknownLoc,
"unhandled type instruction");
946 if (operands.size() != 3)
947 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
949 auto pointeeType =
getType(operands[2]);
951 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
954 uint32_t typePointerID = operands[0];
955 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
958 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
959 deferredStructIt != std::end(deferredStructTypesInfos);) {
960 for (
auto *unresolvedMemberIt =
961 std::begin(deferredStructIt->unresolvedMemberTypes);
962 unresolvedMemberIt !=
963 std::end(deferredStructIt->unresolvedMemberTypes);) {
964 if (unresolvedMemberIt->first == typePointerID) {
968 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
969 typeMap[typePointerID];
971 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
973 ++unresolvedMemberIt;
977 if (deferredStructIt->unresolvedMemberTypes.empty()) {
979 auto structType = deferredStructIt->deferredStructType;
981 assert(structType &&
"expected a spirv::StructType");
982 assert(structType.isIdentified() &&
"expected an indentified struct");
984 if (failed(structType.trySetBody(
985 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
986 deferredStructIt->memberDecorationsInfo)))
989 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1000 if (operands.size() != 3) {
1002 "OpTypeArray must have element type and count parameters");
1007 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1013 auto countInfo = getConstant(operands[2]);
1015 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1016 << operands[2] <<
"can only come from normal constant right now";
1019 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1020 count = intVal.getValue().getZExtValue();
1022 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1023 "scalar integer constant instruction");
1027 elementTy, count, typeDecorations.lookup(operands[0]));
1033 assert(!operands.empty() &&
"No operands for processing function type");
1034 if (operands.size() == 1) {
1035 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1037 auto returnType =
getType(operands[1]);
1039 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1042 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1043 auto ty =
getType(operands[i]);
1045 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1047 argTypes.push_back(ty);
1050 if (!isVoidType(returnType)) {
1057 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1059 if (operands.size() != 6) {
1061 "OpTypeCooperativeMatrixKHR must have element type, "
1062 "scope, row and column parameters, and use");
1068 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1072 std::optional<spirv::Scope> scope =
1073 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1077 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1081 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1082 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1083 IntegerAttr useAttr = getConstantInt(operands[5]);
1086 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1087 "undefined constant <id> ")
1091 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1092 "references undefined constant <id> ")
1096 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1097 "undefined constant <id> ")
1100 unsigned rows = rowsAttr.getInt();
1101 unsigned columns = columnsAttr.getInt();
1103 std::optional<spirv::CooperativeMatrixUseKHR> use =
1104 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1108 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1112 typeMap[operands[0]] =
1119 if (operands.size() != 2) {
1120 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1125 "OpTypeRuntimeArray references undefined <id> ")
1129 memberType, typeDecorations.lookup(operands[0]));
1137 if (operands.empty()) {
1138 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1141 if (operands.size() == 1) {
1143 typeMap[operands[0]] =
1152 for (
auto op : llvm::drop_begin(operands, 1)) {
1154 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1156 if (!memberType && !typeForwardPtr)
1157 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1161 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1163 memberTypes.push_back(memberType);
1168 if (memberDecorationMap.count(operands[0])) {
1169 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1170 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1171 if (allMemberDecorations.count(memberIndex)) {
1172 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1174 if (memberDecoration.first == spirv::Decoration::Offset) {
1176 if (offsetInfo.empty()) {
1177 offsetInfo.resize(memberTypes.size());
1179 offsetInfo[memberIndex] = memberDecoration.second[0];
1181 if (!memberDecoration.second.empty()) {
1182 memberDecorationsInfo.emplace_back(memberIndex, 1,
1183 memberDecoration.first,
1184 memberDecoration.second[0]);
1186 memberDecorationsInfo.emplace_back(memberIndex, 0,
1187 memberDecoration.first, 0);
1195 uint32_t structID = operands[0];
1196 std::string structIdentifier = nameMap.lookup(structID).str();
1198 if (structIdentifier.empty()) {
1199 assert(unresolvedMemberTypes.empty() &&
1200 "didn't expect unresolved member types");
1205 typeMap[structID] = structTy;
1207 if (!unresolvedMemberTypes.empty())
1208 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1209 memberTypes, offsetInfo,
1210 memberDecorationsInfo});
1211 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1212 memberDecorationsInfo)))
1223 if (operands.size() != 3) {
1225 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1226 " (result_id, column_type, and column_count)");
1232 "OpTypeMatrix references undefined column type.")
1236 uint32_t colsCount = operands[2];
1243 if (operands.size() != 2)
1245 "OpTypeForwardPointer instruction must have two operands");
1247 typeForwardPointerIDs.insert(operands[0]);
1257 if (operands.size() != 8)
1260 "OpTypeImage with non-eight operands are not supported yet");
1264 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1267 auto dim = spirv::symbolizeDim(operands[2]);
1269 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1272 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1274 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1277 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1279 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1282 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1284 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1286 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1287 if (!samplerUseInfo)
1288 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1291 auto format = spirv::symbolizeImageFormat(operands[7]);
1293 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1297 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1298 samplingInfo.value(), samplerUseInfo.value(), format.value());
1304 if (operands.size() != 2)
1305 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1310 "OpTypeSampledImage references undefined <id>: ")
1323 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1325 if (operands.size() < 2) {
1327 << opname <<
" must have type <id> and result <id>";
1329 if (operands.size() < 3) {
1331 << opname <<
" must have at least 1 more parameter";
1336 return emitError(unknownLoc,
"undefined result type from <id> ")
1340 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1341 if (bitwidth == 64) {
1342 if (operands.size() == 4) {
1346 << opname <<
" should have 2 parameters for 64-bit values";
1348 if (bitwidth <= 32) {
1349 if (operands.size() == 3) {
1355 <<
" should have 1 parameter for values with no more than 32 bits";
1357 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1361 auto resultID = operands[1];
1363 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1364 auto bitwidth = intType.getWidth();
1365 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1370 if (bitwidth == 64) {
1377 } words = {operands[2], operands[3]};
1378 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1379 }
else if (bitwidth <= 32) {
1380 value = APInt(bitwidth, operands[2],
true,
1384 auto attr = opBuilder.getIntegerAttr(intType, value);
1387 createSpecConstant(unknownLoc, resultID, attr);
1391 constantMap.try_emplace(resultID, attr, intType);
1397 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1398 auto bitwidth = floatType.getWidth();
1399 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1404 if (floatType.isF64()) {
1411 } words = {operands[2], operands[3]};
1412 value = APFloat(llvm::bit_cast<double>(words));
1413 }
else if (floatType.isF32()) {
1414 value = APFloat(llvm::bit_cast<float>(operands[2]));
1415 }
else if (floatType.isF16()) {
1416 APInt data(16, operands[2]);
1417 value = APFloat(APFloat::IEEEhalf(), data);
1418 }
else if (floatType.isBF16()) {
1419 APInt data(16, operands[2]);
1420 value = APFloat(APFloat::BFloat(), data);
1423 auto attr = opBuilder.getFloatAttr(floatType, value);
1425 createSpecConstant(unknownLoc, resultID, attr);
1429 constantMap.try_emplace(resultID, attr, floatType);
1435 return emitError(unknownLoc,
"OpConstant can only generate values of "
1436 "scalar integer or floating-point type");
1439 LogicalResult spirv::Deserializer::processConstantBool(
1441 if (operands.size() != 2) {
1443 << (isSpec ?
"Spec" :
"") <<
"Constant"
1444 << (isTrue ?
"True" :
"False")
1445 <<
" must have type <id> and result <id>";
1448 auto attr = opBuilder.getBoolAttr(isTrue);
1449 auto resultID = operands[1];
1451 createSpecConstant(unknownLoc, resultID, attr);
1455 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1463 if (operands.size() < 2) {
1465 "OpConstantComposite must have type <id> and result <id>");
1467 if (operands.size() < 3) {
1469 "OpConstantComposite must have at least 1 parameter");
1474 return emitError(unknownLoc,
"undefined result type from <id> ")
1479 elements.reserve(operands.size() - 2);
1480 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1481 auto elementInfo = getConstant(operands[i]);
1483 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1484 << operands[i] <<
" must come from a normal constant";
1486 elements.push_back(elementInfo->first);
1489 auto resultID = operands[1];
1490 if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1494 constantMap.try_emplace(resultID, attr, shapedType);
1495 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1496 auto attr = opBuilder.getArrayAttr(elements);
1497 constantMap.try_emplace(resultID, attr, resultType);
1499 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1508 if (operands.size() < 2) {
1510 "OpConstantComposite must have type <id> and result <id>");
1512 if (operands.size() < 3) {
1514 "OpConstantComposite must have at least 1 parameter");
1519 return emitError(unknownLoc,
"undefined result type from <id> ")
1523 auto resultID = operands[1];
1524 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1527 elements.reserve(operands.size() - 2);
1528 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1529 auto elementInfo = getSpecConstant(operands[i]);
1533 auto op = opBuilder.
create<spirv::SpecConstantCompositeOp>(
1535 opBuilder.getArrayAttr(elements));
1536 specConstCompositeMap[resultID] = op;
1543 if (operands.size() < 3)
1544 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1545 "result <id>, and operand opcode");
1547 uint32_t resultTypeID = operands[0];
1550 return emitError(unknownLoc,
"undefined result type from <id> ")
1553 uint32_t resultID = operands[1];
1554 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1555 auto emplaceResult = specConstOperationMap.try_emplace(
1557 SpecConstOperationMaterializationInfo{
1558 enclosedOpcode, resultTypeID,
1561 if (!emplaceResult.second)
1562 return emitError(unknownLoc,
"value with <id>: ")
1563 << resultID <<
" is probably defined before.";
1568 Value spirv::Deserializer::materializeSpecConstantOperation(
1569 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1585 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1586 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
1589 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1590 enclosedOpResultTypeAndOperands.push_back(fakeID);
1591 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1592 enclosedOpOperands.end());
1599 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1606 auto loc = createFileLineColLoc(opBuilder);
1607 auto specConstOperationOp =
1608 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1610 Region &body = specConstOperationOp.getBody();
1612 body.
getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1619 opBuilder.setInsertionPointToEnd(&block);
1621 opBuilder.create<spirv::YieldOp>(loc, block.
front().
getResult(0));
1622 return specConstOperationOp.getResult();
1627 if (operands.size() != 2) {
1629 "OpConstantNull must have type <id> and result <id>");
1634 return emitError(unknownLoc,
"undefined result type from <id> ")
1638 auto resultID = operands[1];
1639 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
1640 auto attr = opBuilder.getZeroAttr(resultType);
1643 constantMap.try_emplace(resultID, attr, resultType);
1647 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
1655 Block *spirv::Deserializer::getOrCreateBlock(uint32_t
id) {
1656 if (
auto *block = getBlock(
id)) {
1657 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
1658 <<
" @ " << block <<
"\n");
1665 auto *block = curFunction->addBlock();
1666 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
1667 <<
" @ " << block <<
"\n");
1668 return blockMap[id] = block;
1673 return emitError(unknownLoc,
"OpBranch must appear inside a block");
1676 if (operands.size() != 1) {
1677 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
1680 auto *target = getOrCreateBlock(operands[0]);
1681 auto loc = createFileLineColLoc(opBuilder);
1685 opBuilder.create<spirv::BranchOp>(loc, target);
1695 "OpBranchConditional must appear inside a block");
1698 if (operands.size() != 3 && operands.size() != 5) {
1700 "OpBranchConditional must have condition, true label, "
1701 "false label, and optionally two branch weights");
1704 auto condition = getValue(operands[0]);
1705 auto *trueBlock = getOrCreateBlock(operands[1]);
1706 auto *falseBlock = getOrCreateBlock(operands[2]);
1708 std::optional<std::pair<uint32_t, uint32_t>> weights;
1709 if (operands.size() == 5) {
1710 weights = std::make_pair(operands[3], operands[4]);
1715 auto loc = createFileLineColLoc(opBuilder);
1716 opBuilder.create<spirv::BranchConditionalOp>(
1717 loc, condition, trueBlock,
1727 return emitError(unknownLoc,
"OpLabel must appear inside a function");
1730 if (operands.size() != 1) {
1731 return emitError(unknownLoc,
"OpLabel should only have result <id>");
1734 auto labelID = operands[0];
1736 auto *block = getOrCreateBlock(labelID);
1737 LLVM_DEBUG(logger.startLine()
1738 <<
"[block] populating block " << block <<
"\n");
1740 assert(block->
empty() &&
"re-deserialize the same block!");
1742 opBuilder.setInsertionPointToStart(block);
1743 blockMap[labelID] = curBlock = block;
1751 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
1754 if (operands.size() < 2) {
1757 "OpSelectionMerge must specify merge target and selection control");
1760 auto *mergeBlock = getOrCreateBlock(operands[0]);
1761 auto loc = createFileLineColLoc(opBuilder);
1762 auto selectionControl = operands[1];
1764 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1768 "a block cannot have more than one OpSelectionMerge instruction");
1777 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
1780 if (operands.size() < 3) {
1781 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
1782 "continue target and loop control");
1785 auto *mergeBlock = getOrCreateBlock(operands[0]);
1786 auto *continueBlock = getOrCreateBlock(operands[1]);
1787 auto loc = createFileLineColLoc(opBuilder);
1788 uint32_t loopControl = operands[2];
1791 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1795 "a block cannot have more than one OpLoopMerge instruction");
1803 return emitError(unknownLoc,
"OpPhi must appear in a block");
1806 if (operands.size() < 4) {
1807 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
1808 "and variable-parent pairs");
1813 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1814 valueMap[operands[1]] = blockArg;
1815 LLVM_DEBUG(logger.startLine()
1816 <<
"[phi] created block argument " << blockArg
1817 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
1821 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
1822 uint32_t value = operands[i];
1823 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1824 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1825 blockPhiInfo[predecessorTargetPair].
push_back(value);
1826 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
1827 <<
" with arg id = " << value <<
"\n");
1836 class ControlFlowStructurizer {
1839 ControlFlowStructurizer(
Location loc, uint32_t control,
1842 llvm::ScopedPrinter &logger)
1843 : location(loc), control(control), blockMergeInfo(mergeInfo),
1844 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1847 ControlFlowStructurizer(
Location loc, uint32_t control,
1850 : location(loc), control(control), blockMergeInfo(mergeInfo),
1851 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1861 LogicalResult structurize();
1866 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1869 spirv::LoopOp createLoopOp(uint32_t loopControl);
1872 void collectBlocksInConstruct();
1881 Block *continueBlock;
1887 llvm::ScopedPrinter &logger;
1893 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1896 OpBuilder builder(&mergeBlock->front());
1898 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
1899 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1900 selectionOp.addMergeBlock(builder);
1905 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1908 OpBuilder builder(&mergeBlock->front());
1910 auto control =
static_cast<spirv::LoopControl
>(loopControl);
1911 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1912 loopOp.addEntryAndMergeBlock(builder);
1917 void ControlFlowStructurizer::collectBlocksInConstruct() {
1918 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
1921 constructBlocks.insert(headerBlock);
1925 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
1926 for (
auto *successor : constructBlocks[i]->getSuccessors())
1927 if (successor != mergeBlock)
1928 constructBlocks.insert(successor);
1932 LogicalResult ControlFlowStructurizer::structurize() {
1934 bool isLoop = continueBlock !=
nullptr;
1936 if (
auto loopOp = createLoopOp(control))
1937 op = loopOp.getOperation();
1939 if (
auto selectionOp = createSelectionOp(control))
1940 op = selectionOp.getOperation();
1949 mapper.
map(mergeBlock, &body.
back());
1951 collectBlocksInConstruct();
1973 for (
auto *block : constructBlocks) {
1976 auto *newBlock = builder.createBlock(&body.
back());
1977 mapper.
map(block, newBlock);
1978 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
1979 <<
" from block " << block <<
"\n");
1983 newBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
1984 mapper.
map(blockArg, newArg);
1985 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
1986 << blockArg <<
" to " << newArg <<
"\n");
1989 LLVM_DEBUG(logger.startLine()
1990 <<
"[cf] block " << block <<
" is a function entry block\n");
1993 for (
auto &op : *block)
1994 newBlock->push_back(op.
clone(mapper));
1998 auto remapOperands = [&](
Operation *op) {
2001 operand.set(mappedOp);
2004 succOp.set(mappedOp);
2006 for (
auto &block : body)
2007 block.walk(remapOperands);
2015 headerBlock->replaceAllUsesWith(mergeBlock);
2018 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2019 headerBlock->getParentOp()->
print(logger.getOStream());
2020 logger.startLine() <<
"\n";
2024 if (!mergeBlock->args_empty()) {
2025 return mergeBlock->getParentOp()->emitError(
2026 "OpPhi in loop merge block unsupported");
2033 mergeBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
2038 if (!headerBlock->args_empty())
2039 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2043 builder.setInsertionPointToEnd(&body.front());
2044 builder.create<spirv::BranchOp>(location, mapper.
lookupOrNull(headerBlock),
2072 body.back().addArgument(blockArg.
getType(), blockArg.
getLoc());
2073 valuesToYield.push_back(body.back().getArguments().back());
2074 outsideUses.push_back(blockArg);
2079 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2082 for (
auto *block : constructBlocks)
2083 block->dropAllReferences();
2088 for (
Block *block : constructBlocks) {
2093 outsideUses.push_back(result);
2097 if (!arg.use_empty()) {
2099 outsideUses.push_back(arg);
2104 assert(valuesToYield.size() == outsideUses.size());
2108 if (!valuesToYield.empty()) {
2109 LLVM_DEBUG(logger.startLine()
2110 <<
"[cf] yielding values from the selection / loop region\n");
2113 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2114 Operation *merge = llvm::getSingleElement(mergeOps);
2116 merge->setOperands(valuesToYield);
2124 builder.setInsertionPoint(&mergeBlock->front());
2129 newOp = builder.
create<spirv::LoopOp>(
2131 static_cast<spirv::LoopControl
>(control));
2133 newOp = builder.
create<spirv::SelectionOp>(
2135 static_cast<spirv::SelectionControl
>(control));
2145 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2146 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2152 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2159 for (
auto *block : constructBlocks) {
2162 return op.
emitOpError(
"failed control flow structurization: value has "
2163 "uses outside of the "
2164 "enclosing selection/loop construct");
2166 if (!arg.use_empty())
2167 return emitError(arg.getLoc(),
"failed control flow structurization: "
2168 "block argument has uses outside of the "
2169 "enclosing selection/loop construct");
2173 for (
auto *block : constructBlocks) {
2214 auto it = blockMergeInfo.find(block);
2215 if (it != blockMergeInfo.end()) {
2221 return emitError(loc,
"failed control flow structurization: nested "
2222 "loop header block should be remapped!");
2224 Block *newContinue = it->second.continueBlock;
2228 return emitError(loc,
"failed control flow structurization: nested "
2229 "loop continue block should be remapped!");
2232 Block *newMerge = it->second.mergeBlock;
2234 newMerge = mappedTo;
2238 blockMergeInfo.
erase(it);
2239 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2246 if (block->walk(updateMergeInfo).wasInterrupted())
2254 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2255 <<
" to only contain a spirv.Branch op\n");
2259 builder.setInsertionPointToEnd(block);
2260 builder.create<spirv::BranchOp>(location, mergeBlock);
2262 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2267 LLVM_DEBUG(logger.startLine()
2268 <<
"[cf] after structurizing construct with header block "
2269 << headerBlock <<
":\n"
2275 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2278 <<
"//----- [phi] start wiring up block arguments -----//\n";
2284 for (
const auto &info : blockPhiInfo) {
2285 Block *block = info.first.first;
2286 Block *target = info.first.second;
2287 const BlockPhiInfo &phiInfo = info.second;
2289 logger.startLine() <<
"[phi] block " << block <<
"\n";
2290 logger.startLine() <<
"[phi] before creating block argument:\n";
2292 logger.startLine() <<
"\n";
2298 opBuilder.setInsertionPoint(op);
2301 blockArgs.reserve(phiInfo.size());
2302 for (uint32_t valueId : phiInfo) {
2303 if (
Value value = getValue(valueId)) {
2304 blockArgs.push_back(value);
2305 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2306 <<
" id = " << valueId <<
"\n");
2308 return emitError(unknownLoc,
"OpPhi references undefined value!");
2312 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2314 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2317 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2318 assert((branchCondOp.getTrueBlock() == target ||
2319 branchCondOp.getFalseBlock() == target) &&
2320 "expected target to be either the true or false target");
2321 if (target == branchCondOp.getTrueTarget())
2322 opBuilder.create<spirv::BranchConditionalOp>(
2323 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2324 branchCondOp.getFalseBlockArguments(),
2325 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2326 branchCondOp.getFalseTarget());
2328 opBuilder.create<spirv::BranchConditionalOp>(
2329 branchCondOp.getLoc(), branchCondOp.getCondition(),
2330 branchCondOp.getTrueBlockArguments(), blockArgs,
2331 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2332 branchCondOp.getFalseBlock());
2334 branchCondOp.erase();
2336 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2340 logger.startLine() <<
"[phi] after creating block argument:\n";
2342 logger.startLine() <<
"\n";
2345 blockPhiInfo.clear();
2350 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2355 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2358 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2360 auto &[block, mergeInfo] = *it;
2363 if (mergeInfo.continueBlock)
2372 if (!isa<spirv::BranchConditionalOp>(terminator))
2376 bool splitHeaderMergeBlock =
false;
2377 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2378 if (mergeInfo.mergeBlock == block)
2379 splitHeaderMergeBlock =
true;
2386 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2389 builder.create<spirv::BranchOp>(block->
getParent()->
getLoc(), newBlock);
2393 blockMergeInfo.erase(block);
2394 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2401 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2402 if (!
options.enableControlFlowStructurization) {
2406 <<
"//----- [cf] skip structurizing control flow -----//\n";
2414 <<
"//----- [cf] start structurizing control flow -----//\n";
2419 logger.startLine() <<
"[cf] split conditional blocks\n";
2420 logger.startLine() <<
"\n";
2423 if (failed(splitConditionalBlocks())) {
2430 while (!blockMergeInfo.empty()) {
2431 Block *headerBlock = blockMergeInfo.
begin()->first;
2432 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2435 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2436 headerBlock->
print(logger.getOStream());
2437 logger.startLine() <<
"\n";
2440 auto *mergeBlock = mergeInfo.mergeBlock;
2441 assert(mergeBlock &&
"merge block cannot be nullptr");
2442 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2443 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2445 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2446 mergeBlock->print(logger.getOStream());
2447 logger.startLine() <<
"\n";
2450 auto *continueBlock = mergeInfo.continueBlock;
2451 LLVM_DEBUG(
if (continueBlock) {
2452 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2453 continueBlock->print(logger.getOStream());
2454 logger.startLine() <<
"\n";
2458 blockMergeInfo.erase(blockMergeInfo.begin());
2459 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2460 blockMergeInfo, headerBlock,
2461 mergeBlock, continueBlock
2467 if (failed(structurizer.structurize()))
2474 <<
"//--- [cf] completed structurizing control flow ---//\n";
2487 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2488 if (fileName.empty())
2489 fileName =
"<unknown>";
2501 if (operands.size() != 3)
2502 return emitError(unknownLoc,
"OpLine must have 3 operands");
2503 debugLine = DebugLine{operands[0], operands[1], operands[2]};
/// Drops the currently remembered debug line information by resetting
/// `debugLine` to `std::nullopt`.
2507 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2511 if (operands.size() < 2)
2512 return emitError(unknownLoc,
"OpString needs at least 2 operands");
2514 if (!debugInfoMap.lookup(operands[0]).empty())
2516 "duplicate debug string found for result <id> ")
2519 unsigned wordIndex = 1;
2521 if (wordIndex != operands.size())
2523 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool mightHaveTerminator()
Check whether this block might have a terminator.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
void push_back(Operation *op)
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< BlockOperand > getBlockOperands()
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Location getLoc()
Return a location for this region.
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.