23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
35 #define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
79 if (failed(sliceInstruction(opcode, operands)))
82 if (failed(processInstruction(opcode, operands)))
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second,
false))) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
103 return std::move(module);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
117 LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
142 #undef MIN_VERSION_CASE
144 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
148 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
159 if (operands.size() != 1)
160 return emitError(unknownLoc,
"OpCapability must have one parameter");
162 auto cap = spirv::symbolizeCapability(operands[0]);
164 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
166 capabilities.insert(*cap);
174 "OpExtension must have a literal string for the extension name");
177 unsigned wordIndex = 0;
179 if (wordIndex != words.size())
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
184 return emitError(unknownLoc,
"unknown extension: ") << extName;
186 extensions.insert(*ext);
192 if (words.size() < 2) {
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
198 unsigned wordIndex = 1;
200 if (wordIndex != words.size()) {
202 "unexpected trailing words in OpExtInstImport");
207 void spirv::Deserializer::attachVCETriple() {
209 spirv::ModuleOp::getVCETripleAttrName(),
211 extensions.getArrayRef(), context));
216 if (operands.size() != 2)
217 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel
>(operands.front())));
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel
>(operands.back())));
231 template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc,
"OpDecoration with ")
238 << decorationName <<
"needs a cache control integer literal and a "
239 << cacheControlKind <<
" cache control literal";
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
243 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
246 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
257 if (words.size() < 2) {
259 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
261 auto decorationName =
262 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (
static_cast<spirv::Decoration
>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc,
"OpDecorate with ")
271 << decorationName <<
" needs a single integer literal";
273 decorations[words[0]].set(
275 static_cast<FPFastMathMode
>(words[2])));
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc,
"OpDecorate with ")
280 << decorationName <<
" needs a single integer literal";
282 decorations[words[0]].set(
284 static_cast<FPRoundingMode
>(words[2])));
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 if (words.size() != 3) {
289 return emitError(unknownLoc,
"OpDecorate with ")
290 << decorationName <<
" needs a single integer literal";
292 decorations[words[0]].set(
293 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
295 case spirv::Decoration::BuiltIn:
296 if (words.size() != 3) {
297 return emitError(unknownLoc,
"OpDecorate with ")
298 << decorationName <<
" needs a single integer literal";
300 decorations[words[0]].set(
301 symbol, opBuilder.getStringAttr(
302 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
304 case spirv::Decoration::ArrayStride:
305 if (words.size() != 3) {
306 return emitError(unknownLoc,
"OpDecorate with ")
307 << decorationName <<
" needs a single integer literal";
309 typeDecorations[words[0]] = words[2];
311 case spirv::Decoration::LinkageAttributes: {
312 if (words.size() < 4) {
313 return emitError(unknownLoc,
"OpDecorate with ")
315 <<
" needs at least 1 string and 1 integer literal";
323 unsigned wordIndex = 2;
325 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
326 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
327 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
329 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
332 case spirv::Decoration::Aliased:
333 case spirv::Decoration::AliasedPointer:
334 case spirv::Decoration::Block:
335 case spirv::Decoration::BufferBlock:
336 case spirv::Decoration::Flat:
337 case spirv::Decoration::NonReadable:
338 case spirv::Decoration::NonWritable:
339 case spirv::Decoration::NoPerspective:
340 case spirv::Decoration::NoSignedWrap:
341 case spirv::Decoration::NoUnsignedWrap:
342 case spirv::Decoration::RelaxedPrecision:
343 case spirv::Decoration::Restrict:
344 case spirv::Decoration::RestrictPointer:
345 case spirv::Decoration::NoContraction:
346 case spirv::Decoration::Constant:
347 case spirv::Decoration::Invariant:
348 case spirv::Decoration::Patch:
349 if (words.size() != 2) {
350 return emitError(unknownLoc,
"OpDecoration with ")
351 << decorationName <<
"needs a single target <id>";
353 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
355 case spirv::Decoration::Location:
356 case spirv::Decoration::SpecId:
357 if (words.size() != 3) {
358 return emitError(unknownLoc,
"OpDecoration with ")
359 << decorationName <<
"needs a single integer literal";
361 decorations[words[0]].set(
362 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
364 case spirv::Decoration::CacheControlLoadINTEL: {
366 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
373 case spirv::Decoration::CacheControlStoreINTEL: {
375 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
383 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
391 if (words.size() < 3) {
393 "OpMemberDecorate must have at least 3 operands");
396 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
397 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
399 " missing offset specification in OpMemberDecorate with "
400 "Offset decoration");
403 if (words.size() > 3) {
404 decorationOperands = words.slice(3);
406 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
411 if (words.size() < 3) {
412 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
414 unsigned wordIndex = 2;
416 if (wordIndex != words.size()) {
418 "unexpected trailing words in OpMemberName instruction");
420 memberNameMap[words[0]][words[1]] = name;
424 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
426 if (!decorations.contains(argID)) {
431 spirv::DecorationAttr foundDecorationAttr;
433 for (
auto decoration :
434 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435 spirv::Decoration::AliasedPointer,
436 spirv::Decoration::RestrictPointer}) {
438 if (decAttr.getName() !=
439 getSymbolDecoration(stringifyDecoration(decoration)))
442 if (foundDecorationAttr)
444 "more than one Aliased/Restrict decorations for "
445 "function argument with result <id> ")
452 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
453 spirv::Decoration::RelaxedPrecision))) {
458 if (foundDecorationAttr)
459 return emitError(unknownLoc,
"already found a decoration for function "
460 "argument with result <id> ")
464 context, spirv::Decoration::RelaxedPrecision);
468 if (!foundDecorationAttr)
469 return emitError(unknownLoc,
"unimplemented decoration support for "
470 "function argument with result <id> ")
474 foundDecorationAttr);
482 return emitError(unknownLoc,
"found function inside function");
486 if (operands.size() != 4) {
487 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
491 return emitError(unknownLoc,
"undefined result type from <id> ")
495 uint32_t fnID = operands[1];
496 if (funcMap.count(fnID)) {
497 return emitError(unknownLoc,
"duplicate function definition/declaration");
500 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
502 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
506 if (!fnType || !isa<FunctionType>(fnType)) {
507 return emitError(unknownLoc,
"unknown function type from <id> ")
510 auto functionType = cast<FunctionType>(fnType);
512 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
513 (functionType.getNumResults() == 1 &&
514 functionType.getResult(0) != resultType)) {
515 return emitError(unknownLoc,
"mismatch in function type ")
516 << functionType <<
" and return type " << resultType <<
" specified";
519 std::string fnName = getFunctionSymbol(fnID);
520 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
521 functionType, fnControl.value());
523 if (decorations.count(fnID)) {
524 for (
auto attr : decorations[fnID].getAttrs()) {
525 funcOp->setAttr(attr.getName(), attr.getValue());
528 curFunction = funcMap[fnID] = funcOp;
529 auto *entryBlock = funcOp.addEntryBlock();
532 <<
"//===-------------------------------------------===//\n";
533 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
534 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
535 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
536 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
541 argAttrs.resize(functionType.getNumInputs());
544 if (functionType.getNumInputs()) {
545 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
546 auto argType = functionType.getInput(i);
547 spirv::Opcode opcode = spirv::Opcode::OpNop;
549 if (failed(sliceInstruction(opcode, operands,
550 spirv::Opcode::OpFunctionParameter))) {
553 if (opcode != spirv::Opcode::OpFunctionParameter) {
556 "missing OpFunctionParameter instruction for argument ")
559 if (operands.size() != 2) {
562 "expected result type and result <id> for OpFunctionParameter");
564 auto argDefinedType =
getType(operands[0]);
565 if (!argDefinedType || argDefinedType != argType) {
567 "mismatch in argument type between function type "
569 << functionType <<
" and argument type definition "
570 << argDefinedType <<
" at argument " << i;
572 if (getValue(operands[1])) {
573 return emitError(unknownLoc,
"duplicate definition of result <id> ")
576 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
580 auto argValue = funcOp.getArgument(i);
581 valueMap[operands[1]] = argValue;
585 if (llvm::any_of(argAttrs, [](
Attribute attr) {
586 auto argAttr = cast<DictionaryAttr>(attr);
587 return !argAttr.empty();
594 auto linkageAttr = funcOp.getLinkageAttributes();
595 auto hasImportLinkage =
596 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
597 spirv::LinkageType::Import);
598 if (hasImportLinkage)
605 spirv::Opcode opcode = spirv::Opcode::OpNop;
613 if (failed(sliceInstruction(opcode, instOperands,
614 spirv::Opcode::OpFunctionEnd))) {
617 if (opcode == spirv::Opcode::OpFunctionEnd) {
618 return processFunctionEnd(instOperands);
620 if (opcode != spirv::Opcode::OpLabel) {
621 return emitError(unknownLoc,
"a basic block must start with OpLabel");
623 if (instOperands.size() != 1) {
624 return emitError(unknownLoc,
"OpLabel should only have result <id>");
626 blockMap[instOperands[0]] = entryBlock;
627 if (failed(processLabel(instOperands))) {
633 while (succeeded(sliceInstruction(opcode, instOperands,
634 spirv::Opcode::OpFunctionEnd)) &&
635 opcode != spirv::Opcode::OpFunctionEnd) {
636 if (failed(processInstruction(opcode, instOperands))) {
640 if (opcode != spirv::Opcode::OpFunctionEnd) {
644 return processFunctionEnd(instOperands);
650 if (!operands.empty()) {
651 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
657 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
662 curFunction = std::nullopt;
667 <<
"//===-------------------------------------------===//\n";
672 std::optional<std::pair<Attribute, Type>>
673 spirv::Deserializer::getConstant(uint32_t
id) {
674 auto constIt = constantMap.find(
id);
675 if (constIt == constantMap.end())
677 return constIt->getSecond();
680 std::optional<std::pair<Attribute, Type>>
681 spirv::Deserializer::getConstantCompositeReplicate(uint32_t
id) {
682 if (
auto it = constantCompositeReplicateMap.find(
id);
683 it != constantCompositeReplicateMap.end())
688 std::optional<spirv::SpecConstOperationMaterializationInfo>
689 spirv::Deserializer::getSpecConstantOperation(uint32_t
id) {
690 auto constIt = specConstOperationMap.find(
id);
691 if (constIt == specConstOperationMap.end())
693 return constIt->getSecond();
696 std::string spirv::Deserializer::getFunctionSymbol(uint32_t
id) {
697 auto funcName = nameMap.lookup(
id).str();
698 if (funcName.empty()) {
699 funcName =
"spirv_fn_" + std::to_string(
id);
704 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t
id) {
705 auto constName = nameMap.lookup(
id).str();
706 if (constName.empty()) {
707 constName =
"spirv_spec_const_" + std::to_string(
id);
712 spirv::SpecConstantOp
713 spirv::Deserializer::createSpecConstant(
Location loc, uint32_t resultID,
714 TypedAttr defaultValue) {
715 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
716 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
718 if (decorations.count(resultID)) {
719 for (
auto attr : decorations[resultID].getAttrs())
720 op->setAttr(attr.getName(), attr.getValue());
722 specConstMap[resultID] = op;
728 unsigned wordIndex = 0;
729 if (operands.size() < 3) {
732 "OpVariable needs at least 3 operands, type, <id> and storage class");
736 auto type =
getType(operands[wordIndex]);
738 return emitError(unknownLoc,
"unknown result type <id> : ")
739 << operands[wordIndex];
741 auto ptrType = dyn_cast<spirv::PointerType>(type);
744 "expected a result type <id> to be a spirv.ptr, found : ")
750 auto variableID = operands[wordIndex];
751 auto variableName = nameMap.lookup(variableID).str();
752 if (variableName.empty()) {
753 variableName =
"spirv_var_" + std::to_string(variableID);
758 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
759 if (ptrType.getStorageClass() != storageClass) {
760 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
761 << type <<
" and that specified in OpVariable instruction : "
762 << stringifyStorageClass(storageClass);
769 if (wordIndex < operands.size()) {
772 if (
auto initOp = getGlobalVariable(operands[wordIndex]))
774 else if (
auto initOp = getSpecConstant(operands[wordIndex]))
776 else if (
auto initOp = getSpecConstantComposite(operands[wordIndex]))
779 return emitError(unknownLoc,
"unknown <id> ")
780 << operands[wordIndex] <<
"used as initializer";
785 if (wordIndex != operands.size()) {
787 "found more operands than expected when deserializing "
788 "OpVariable instruction, only ")
789 << wordIndex <<
" of " << operands.size() <<
" processed";
791 auto loc = createFileLineColLoc(opBuilder);
792 auto varOp = spirv::GlobalVariableOp::create(
794 opBuilder.getStringAttr(variableName), initializer);
797 if (decorations.count(variableID)) {
798 for (
auto attr : decorations[variableID].getAttrs())
799 varOp->setAttr(attr.getName(), attr.getValue());
801 globalVariableMap[variableID] = varOp;
805 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t
id) {
806 auto constInfo = getConstant(
id);
810 return dyn_cast<IntegerAttr>(constInfo->first);
814 if (operands.size() < 2) {
815 return emitError(unknownLoc,
"OpName needs at least 2 operands");
817 if (!nameMap.lookup(operands[0]).empty()) {
818 return emitError(unknownLoc,
"duplicate name found for result <id> ")
821 unsigned wordIndex = 1;
823 if (wordIndex != operands.size()) {
825 "unexpected trailing words in OpName instruction");
827 nameMap[operands[0]] = name;
835 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
837 if (operands.empty()) {
838 return emitError(unknownLoc,
"type instruction with opcode ")
839 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
844 if (typeMap.count(operands[0])) {
845 return emitError(unknownLoc,
"duplicate definition for result <id> ")
850 case spirv::Opcode::OpTypeVoid:
851 if (operands.size() != 1)
852 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
853 typeMap[operands[0]] = opBuilder.getNoneType();
855 case spirv::Opcode::OpTypeBool:
856 if (operands.size() != 1)
857 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
858 typeMap[operands[0]] = opBuilder.getI1Type();
860 case spirv::Opcode::OpTypeInt: {
861 if (operands.size() != 3)
863 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
873 : IntegerType::SignednessSemantics::Signless;
876 case spirv::Opcode::OpTypeFloat: {
877 if (operands.size() != 2 && operands.size() != 3)
879 "OpTypeFloat expects either 2 operands (type, bitwidth) "
880 "or 3 operands (type, bitwidth, encoding), but got ")
882 uint32_t bitWidth = operands[1];
887 floatTy = opBuilder.getF16Type();
890 floatTy = opBuilder.getF32Type();
893 floatTy = opBuilder.getF64Type();
896 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
900 if (operands.size() == 3) {
901 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
902 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
906 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
907 << bitWidth <<
" (expected 16)";
908 floatTy = opBuilder.getBF16Type();
911 typeMap[operands[0]] = floatTy;
913 case spirv::Opcode::OpTypeVector: {
914 if (operands.size() != 3) {
917 "OpTypeVector must have element type and count parameters");
921 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
926 case spirv::Opcode::OpTypePointer: {
927 return processOpTypePointer(operands);
929 case spirv::Opcode::OpTypeArray:
930 return processArrayType(operands);
931 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
932 return processCooperativeMatrixTypeKHR(operands);
933 case spirv::Opcode::OpTypeFunction:
934 return processFunctionType(operands);
935 case spirv::Opcode::OpTypeImage:
936 return processImageType(operands);
937 case spirv::Opcode::OpTypeSampledImage:
938 return processSampledImageType(operands);
939 case spirv::Opcode::OpTypeRuntimeArray:
940 return processRuntimeArrayType(operands);
941 case spirv::Opcode::OpTypeStruct:
942 return processStructType(operands);
943 case spirv::Opcode::OpTypeMatrix:
944 return processMatrixType(operands);
945 case spirv::Opcode::OpTypeTensorARM:
946 return processTensorARMType(operands);
948 return emitError(unknownLoc,
"unhandled type instruction");
955 if (operands.size() != 3)
956 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
958 auto pointeeType =
getType(operands[2]);
960 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
963 uint32_t typePointerID = operands[0];
964 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
967 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
968 deferredStructIt != std::end(deferredStructTypesInfos);) {
969 for (
auto *unresolvedMemberIt =
970 std::begin(deferredStructIt->unresolvedMemberTypes);
971 unresolvedMemberIt !=
972 std::end(deferredStructIt->unresolvedMemberTypes);) {
973 if (unresolvedMemberIt->first == typePointerID) {
977 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
978 typeMap[typePointerID];
980 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
982 ++unresolvedMemberIt;
986 if (deferredStructIt->unresolvedMemberTypes.empty()) {
988 auto structType = deferredStructIt->deferredStructType;
990 assert(structType &&
"expected a spirv::StructType");
991 assert(structType.isIdentified() &&
"expected an indentified struct");
993 if (failed(structType.trySetBody(
994 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
995 deferredStructIt->memberDecorationsInfo,
996 deferredStructIt->structDecorationsInfo)))
999 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1010 if (operands.size() != 3) {
1012 "OpTypeArray must have element type and count parameters");
1017 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1023 auto countInfo = getConstant(operands[2]);
1025 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1026 << operands[2] <<
"can only come from normal constant right now";
1029 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1030 count = intVal.getValue().getZExtValue();
1032 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1033 "scalar integer constant instruction");
1037 elementTy, count, typeDecorations.lookup(operands[0]));
1043 assert(!operands.empty() &&
"No operands for processing function type");
1044 if (operands.size() == 1) {
1045 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1047 auto returnType =
getType(operands[1]);
1049 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1052 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1053 auto ty =
getType(operands[i]);
1055 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1057 argTypes.push_back(ty);
1060 if (!isVoidType(returnType)) {
1067 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1069 if (operands.size() != 6) {
1071 "OpTypeCooperativeMatrixKHR must have element type, "
1072 "scope, row and column parameters, and use");
1078 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1082 std::optional<spirv::Scope> scope =
1083 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1087 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1091 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1092 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1093 IntegerAttr useAttr = getConstantInt(operands[5]);
1096 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1097 "undefined constant <id> ")
1101 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1102 "references undefined constant <id> ")
1106 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1107 "undefined constant <id> ")
1110 unsigned rows = rowsAttr.getInt();
1111 unsigned columns = columnsAttr.getInt();
1113 std::optional<spirv::CooperativeMatrixUseKHR> use =
1114 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1118 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1122 typeMap[operands[0]] =
1129 if (operands.size() != 2) {
1130 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1135 "OpTypeRuntimeArray references undefined <id> ")
1139 memberType, typeDecorations.lookup(operands[0]));
1147 if (operands.empty()) {
1148 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1151 if (operands.size() == 1) {
1153 typeMap[operands[0]] =
1162 for (
auto op : llvm::drop_begin(operands, 1)) {
1164 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1166 if (!memberType && !typeForwardPtr)
1167 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1171 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1173 memberTypes.push_back(memberType);
1178 if (memberDecorationMap.count(operands[0])) {
1179 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1180 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1181 if (allMemberDecorations.count(memberIndex)) {
1182 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1184 if (memberDecoration.first == spirv::Decoration::Offset) {
1186 if (offsetInfo.empty()) {
1187 offsetInfo.resize(memberTypes.size());
1189 offsetInfo[memberIndex] = memberDecoration.second[0];
1192 if (!memberDecoration.second.empty()) {
1193 memberDecorationsInfo.emplace_back(
1194 memberIndex, memberDecoration.first,
1197 memberDecorationsInfo.emplace_back(
1198 memberIndex, memberDecoration.first,
UnitAttr::get(context));
1207 if (decorations.count(operands[0])) {
1210 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1211 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(),
true));
1212 assert(decoration.has_value());
1213 structDecorationsInfo.emplace_back(decoration.value(),
1214 decorationAttr.getValue());
1218 uint32_t structID = operands[0];
1219 std::string structIdentifier = nameMap.lookup(structID).str();
1221 if (structIdentifier.empty()) {
1222 assert(unresolvedMemberTypes.empty() &&
1223 "didn't expect unresolved member types");
1225 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1228 typeMap[structID] = structTy;
1230 if (!unresolvedMemberTypes.empty())
1231 deferredStructTypesInfos.push_back(
1232 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1233 memberDecorationsInfo, structDecorationsInfo});
1234 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1235 memberDecorationsInfo,
1236 structDecorationsInfo)))
1247 if (operands.size() != 3) {
1249 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1250 " (result_id, column_type, and column_count)");
1256 "OpTypeMatrix references undefined column type.")
1260 uint32_t colsCount = operands[2];
1267 unsigned size = operands.size();
1268 if (size < 2 || size > 4)
1269 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1270 "(result_id, element_type, (rank), (shape)) ")
1276 "OpTypeTensorARM references undefined element type ")
1284 IntegerAttr rankAttr = getConstantInt(operands[2]);
1286 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1287 "scalar integer constant instruction");
1288 unsigned rank = rankAttr.getValue().getZExtValue();
1295 std::optional<std::pair<Attribute, Type>> shapeInfo =
1296 getConstant(operands[3]);
1298 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1299 "constant instruction of type OpTypeArray");
1301 ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
1303 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1304 auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
1306 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1308 shape.push_back(dimIntAttr.getValue().getSExtValue());
1316 if (operands.size() != 2)
1318 "OpTypeForwardPointer instruction must have two operands");
1320 typeForwardPointerIDs.insert(operands[0]);
1330 if (operands.size() != 8)
1333 "OpTypeImage with non-eight operands are not supported yet");
1337 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1340 auto dim = spirv::symbolizeDim(operands[2]);
1342 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1345 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1347 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1350 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1352 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1355 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1357 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1359 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1360 if (!samplerUseInfo)
1361 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1364 auto format = spirv::symbolizeImageFormat(operands[7]);
1366 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1370 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1371 samplingInfo.value(), samplerUseInfo.value(), format.value());
1377 if (operands.size() != 2)
1378 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1383 "OpTypeSampledImage references undefined <id>: ")
1396 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1398 if (operands.size() < 2) {
1400 << opname <<
" must have type <id> and result <id>";
1402 if (operands.size() < 3) {
1404 << opname <<
" must have at least 1 more parameter";
1409 return emitError(unknownLoc,
"undefined result type from <id> ")
1413 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1414 if (bitwidth == 64) {
1415 if (operands.size() == 4) {
1419 << opname <<
" should have 2 parameters for 64-bit values";
1421 if (bitwidth <= 32) {
1422 if (operands.size() == 3) {
1428 <<
" should have 1 parameter for values with no more than 32 bits";
1430 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1434 auto resultID = operands[1];
1436 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1437 auto bitwidth = intType.getWidth();
1438 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1443 if (bitwidth == 64) {
1450 } words = {operands[2], operands[3]};
1451 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1452 }
else if (bitwidth <= 32) {
1453 value = APInt(bitwidth, operands[2],
true,
1457 auto attr = opBuilder.getIntegerAttr(intType, value);
1460 createSpecConstant(unknownLoc, resultID, attr);
1464 constantMap.try_emplace(resultID, attr, intType);
1470 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1471 auto bitwidth = floatType.getWidth();
1472 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1477 if (floatType.isF64()) {
1484 } words = {operands[2], operands[3]};
1485 value = APFloat(llvm::bit_cast<double>(words));
1486 }
else if (floatType.isF32()) {
1487 value = APFloat(llvm::bit_cast<float>(operands[2]));
1488 }
else if (floatType.isF16()) {
1489 APInt data(16, operands[2]);
1490 value = APFloat(APFloat::IEEEhalf(), data);
1491 }
else if (floatType.isBF16()) {
1492 APInt data(16, operands[2]);
1493 value = APFloat(APFloat::BFloat(), data);
1496 auto attr = opBuilder.getFloatAttr(floatType, value);
1498 createSpecConstant(unknownLoc, resultID, attr);
1502 constantMap.try_emplace(resultID, attr, floatType);
1508 return emitError(unknownLoc,
"OpConstant can only generate values of "
1509 "scalar integer or floating-point type");
1512 LogicalResult spirv::Deserializer::processConstantBool(
1514 if (operands.size() != 2) {
1516 << (isSpec ?
"Spec" :
"") <<
"Constant"
1517 << (isTrue ?
"True" :
"False")
1518 <<
" must have type <id> and result <id>";
1521 auto attr = opBuilder.getBoolAttr(isTrue);
1522 auto resultID = operands[1];
1524 createSpecConstant(unknownLoc, resultID, attr);
1528 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1536 if (operands.size() < 2) {
1538 "OpConstantComposite must have type <id> and result <id>");
1540 if (operands.size() < 3) {
1542 "OpConstantComposite must have at least 1 parameter");
1547 return emitError(unknownLoc,
"undefined result type from <id> ")
1552 elements.reserve(operands.size() - 2);
1553 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1554 auto elementInfo = getConstant(operands[i]);
1556 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1557 << operands[i] <<
" must come from a normal constant";
1559 elements.push_back(elementInfo->first);
1562 auto resultID = operands[1];
1563 if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1566 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1567 for (
auto value : denseElemAttr.getValues<
Attribute>())
1568 flattenedElems.push_back(value);
1570 flattenedElems.push_back(element);
1574 constantMap.try_emplace(resultID, attr, tensorType);
1575 }
else if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1579 constantMap.try_emplace(resultID, attr, shapedType);
1580 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1581 auto attr = opBuilder.getArrayAttr(elements);
1582 constantMap.try_emplace(resultID, attr, resultType);
1584 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1591 LogicalResult spirv::Deserializer::processConstantCompositeReplicateEXT(
1593 if (operands.size() != 3) {
1596 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1602 return emitError(unknownLoc,
"undefined result type from <id> ")
1606 auto compositeType = dyn_cast<CompositeType>(resultType);
1607 if (!compositeType) {
1609 "result type from <id> is not a composite type")
1613 uint32_t resultID = operands[1];
1614 uint32_t constantID = operands[2];
1616 std::optional<std::pair<Attribute, Type>> constantInfo =
1617 getConstant(constantID);
1618 if (constantInfo.has_value()) {
1619 constantCompositeReplicateMap.try_emplace(
1620 resultID, constantInfo.value().first, resultType);
1624 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1625 getConstantCompositeReplicate(constantID);
1626 if (replicatedConstantCompositeInfo.has_value()) {
1627 constantCompositeReplicateMap.try_emplace(
1628 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1632 return emitError(unknownLoc,
"OpConstantCompositeReplicateEXT operand <id> ")
1634 <<
" must come from a normal constant or a "
1635 "OpConstantCompositeReplicateEXT";
1640 if (operands.size() < 2) {
1643 "OpSpecConstantComposite must have type <id> and result <id>");
1645 if (operands.size() < 3) {
1647 "OpSpecConstantComposite must have at least 1 parameter");
1652 return emitError(unknownLoc,
"undefined result type from <id> ")
1656 auto resultID = operands[1];
1657 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1660 elements.reserve(operands.size() - 2);
1661 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1662 auto elementInfo = getSpecConstant(operands[i]);
1666 auto op = spirv::SpecConstantCompositeOp::create(
1668 opBuilder.getArrayAttr(elements));
1669 specConstCompositeMap[resultID] = op;
1674 LogicalResult spirv::Deserializer::processSpecConstantCompositeReplicateEXT(
1676 if (operands.size() != 3) {
1677 return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT expects "
1678 "3 operands but found ")
1684 return emitError(unknownLoc,
"undefined result type from <id> ")
1688 auto compositeType = dyn_cast<CompositeType>(resultType);
1689 if (!compositeType) {
1691 "result type from <id> is not a composite type")
1695 uint32_t resultID = operands[1];
1697 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1698 spirv::SpecConstantOp constituentSpecConstantOp =
1699 getSpecConstant(operands[2]);
1700 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1704 specConstCompositeReplicateMap[resultID] = op;
1711 if (operands.size() < 3)
1712 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1713 "result <id>, and operand opcode");
1715 uint32_t resultTypeID = operands[0];
1718 return emitError(unknownLoc,
"undefined result type from <id> ")
1721 uint32_t resultID = operands[1];
1722 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1723 auto emplaceResult = specConstOperationMap.try_emplace(
1725 SpecConstOperationMaterializationInfo{
1726 enclosedOpcode, resultTypeID,
1729 if (!emplaceResult.second)
1730 return emitError(unknownLoc,
"value with <id>: ")
1731 << resultID <<
" is probably defined before.";
1736 Value spirv::Deserializer::materializeSpecConstantOperation(
1737 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1753 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1754 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
1757 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1758 enclosedOpResultTypeAndOperands.push_back(fakeID);
1759 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1760 enclosedOpOperands.end());
1767 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1774 auto loc = createFileLineColLoc(opBuilder);
1775 auto specConstOperationOp =
1776 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
1778 Region &body = specConstOperationOp.getBody();
1780 body.
getBlocks().splice(body.
end(), curBlock->getParent()->getBlocks(),
1787 opBuilder.setInsertionPointToEnd(&block);
1789 spirv::YieldOp::create(opBuilder, loc, block.
front().
getResult(0));
1790 return specConstOperationOp.getResult();
1795 if (operands.size() != 2) {
1797 "OpConstantNull must only have type <id> and result <id>");
1802 return emitError(unknownLoc,
"undefined result type from <id> ")
1806 auto resultID = operands[1];
1808 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
1809 attr = opBuilder.getZeroAttr(resultType);
1810 }
else if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1811 if (
auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
1818 constantMap.try_emplace(resultID, attr, resultType);
1822 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
1830 Block *spirv::Deserializer::getOrCreateBlock(uint32_t
id) {
1831 if (
auto *block = getBlock(
id)) {
1832 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
1833 <<
" @ " << block <<
"\n");
1840 auto *block = curFunction->addBlock();
1841 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
1842 <<
" @ " << block <<
"\n");
1843 return blockMap[id] = block;
1848 return emitError(unknownLoc,
"OpBranch must appear inside a block");
1851 if (operands.size() != 1) {
1852 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
1855 auto *target = getOrCreateBlock(operands[0]);
1856 auto loc = createFileLineColLoc(opBuilder);
1860 spirv::BranchOp::create(opBuilder, loc, target);
1870 "OpBranchConditional must appear inside a block");
1873 if (operands.size() != 3 && operands.size() != 5) {
1875 "OpBranchConditional must have condition, true label, "
1876 "false label, and optionally two branch weights");
1879 auto condition = getValue(operands[0]);
1880 auto *trueBlock = getOrCreateBlock(operands[1]);
1881 auto *falseBlock = getOrCreateBlock(operands[2]);
1883 std::optional<std::pair<uint32_t, uint32_t>> weights;
1884 if (operands.size() == 5) {
1885 weights = std::make_pair(operands[3], operands[4]);
1890 auto loc = createFileLineColLoc(opBuilder);
1891 spirv::BranchConditionalOp::create(
1892 opBuilder, loc, condition, trueBlock,
1902 return emitError(unknownLoc,
"OpLabel must appear inside a function");
1905 if (operands.size() != 1) {
1906 return emitError(unknownLoc,
"OpLabel should only have result <id>");
1909 auto labelID = operands[0];
1911 auto *block = getOrCreateBlock(labelID);
1912 LLVM_DEBUG(logger.startLine()
1913 <<
"[block] populating block " << block <<
"\n");
1915 assert(block->
empty() &&
"re-deserialize the same block!");
1917 opBuilder.setInsertionPointToStart(block);
1918 blockMap[labelID] = curBlock = block;
1926 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
1929 if (operands.size() < 2) {
1932 "OpSelectionMerge must specify merge target and selection control");
1935 auto *mergeBlock = getOrCreateBlock(operands[0]);
1936 auto loc = createFileLineColLoc(opBuilder);
1937 auto selectionControl = operands[1];
1939 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1943 "a block cannot have more than one OpSelectionMerge instruction");
1952 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
1955 if (operands.size() < 3) {
1956 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
1957 "continue target and loop control");
1960 auto *mergeBlock = getOrCreateBlock(operands[0]);
1961 auto *continueBlock = getOrCreateBlock(operands[1]);
1962 auto loc = createFileLineColLoc(opBuilder);
1963 uint32_t loopControl = operands[2];
1966 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1970 "a block cannot have more than one OpLoopMerge instruction");
1978 return emitError(unknownLoc,
"OpPhi must appear in a block");
1981 if (operands.size() < 4) {
1982 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
1983 "and variable-parent pairs");
1988 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1989 valueMap[operands[1]] = blockArg;
1990 LLVM_DEBUG(logger.startLine()
1991 <<
"[phi] created block argument " << blockArg
1992 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
1996 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
1997 uint32_t value = operands[i];
1998 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1999 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2000 blockPhiInfo[predecessorTargetPair].
push_back(value);
2001 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
2002 <<
" with arg id = " << value <<
"\n");
2011 class ControlFlowStructurizer {
2014 ControlFlowStructurizer(
Location loc, uint32_t control,
2017 llvm::ScopedPrinter &logger)
2018 : location(loc), control(control), blockMergeInfo(mergeInfo),
2019 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2022 ControlFlowStructurizer(
Location loc, uint32_t control,
2025 : location(loc), control(control), blockMergeInfo(mergeInfo),
2026 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2036 LogicalResult structurize();
2041 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2044 spirv::LoopOp createLoopOp(uint32_t loopControl);
2047 void collectBlocksInConstruct();
2056 Block *continueBlock;
2062 llvm::ScopedPrinter &logger;
2068 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2071 OpBuilder builder(&mergeBlock->front());
2073 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
2074 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2075 selectionOp.addMergeBlock(builder);
2080 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2083 OpBuilder builder(&mergeBlock->front());
2085 auto control =
static_cast<spirv::LoopControl
>(loopControl);
2086 auto loopOp = spirv::LoopOp::create(builder, location, control);
2087 loopOp.addEntryAndMergeBlock(builder);
2092 void ControlFlowStructurizer::collectBlocksInConstruct() {
2093 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
2096 constructBlocks.insert(headerBlock);
2100 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
2101 for (
auto *successor : constructBlocks[i]->getSuccessors())
2102 if (successor != mergeBlock)
2103 constructBlocks.insert(successor);
2107 LogicalResult ControlFlowStructurizer::structurize() {
2109 bool isLoop = continueBlock !=
nullptr;
2111 if (
auto loopOp = createLoopOp(control))
2112 op = loopOp.getOperation();
2114 if (
auto selectionOp = createSelectionOp(control))
2115 op = selectionOp.getOperation();
2124 mapper.
map(mergeBlock, &body.
back());
2126 collectBlocksInConstruct();
2148 for (
auto *block : constructBlocks) {
2151 auto *newBlock = builder.createBlock(&body.
back());
2152 mapper.
map(block, newBlock);
2153 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2154 <<
" from block " << block <<
"\n");
2158 newBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
2159 mapper.
map(blockArg, newArg);
2160 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2161 << blockArg <<
" to " << newArg <<
"\n");
2164 LLVM_DEBUG(logger.startLine()
2165 <<
"[cf] block " << block <<
" is a function entry block\n");
2168 for (
auto &op : *block)
2169 newBlock->push_back(op.
clone(mapper));
2173 auto remapOperands = [&](
Operation *op) {
2176 operand.set(mappedOp);
2179 succOp.set(mappedOp);
2181 for (
auto &block : body)
2182 block.walk(remapOperands);
2190 headerBlock->replaceAllUsesWith(mergeBlock);
2193 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2194 headerBlock->getParentOp()->
print(logger.getOStream());
2195 logger.startLine() <<
"\n";
2199 if (!mergeBlock->args_empty()) {
2200 return mergeBlock->getParentOp()->emitError(
2201 "OpPhi in loop merge block unsupported");
2208 mergeBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
2213 if (!headerBlock->args_empty())
2214 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2218 builder.setInsertionPointToEnd(&body.front());
2219 spirv::BranchOp::create(builder, location, mapper.
lookupOrNull(headerBlock),
2247 body.back().addArgument(blockArg.
getType(), blockArg.
getLoc());
2248 valuesToYield.push_back(body.back().getArguments().back());
2249 outsideUses.push_back(blockArg);
2254 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2257 for (
auto *block : constructBlocks)
2258 block->dropAllReferences();
2263 for (
Block *block : constructBlocks) {
2268 outsideUses.push_back(result);
2272 if (!arg.use_empty()) {
2274 outsideUses.push_back(arg);
2279 assert(valuesToYield.size() == outsideUses.size());
2283 if (!valuesToYield.empty()) {
2284 LLVM_DEBUG(logger.startLine()
2285 <<
"[cf] yielding values from the selection / loop region\n");
2288 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2289 Operation *merge = llvm::getSingleElement(mergeOps);
2291 merge->setOperands(valuesToYield);
2299 builder.setInsertionPoint(&mergeBlock->front());
2304 newOp = spirv::LoopOp::create(builder, location,
2306 static_cast<spirv::LoopControl
>(control));
2308 newOp = spirv::SelectionOp::create(
2310 static_cast<spirv::SelectionControl
>(control));
2320 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2321 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2327 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2334 for (
auto *block : constructBlocks) {
2337 return op.
emitOpError(
"failed control flow structurization: value has "
2338 "uses outside of the "
2339 "enclosing selection/loop construct");
2341 if (!arg.use_empty())
2342 return emitError(arg.getLoc(),
"failed control flow structurization: "
2343 "block argument has uses outside of the "
2344 "enclosing selection/loop construct");
2348 for (
auto *block : constructBlocks) {
2389 auto it = blockMergeInfo.find(block);
2390 if (it != blockMergeInfo.end()) {
2396 return emitError(loc,
"failed control flow structurization: nested "
2397 "loop header block should be remapped!");
2399 Block *newContinue = it->second.continueBlock;
2403 return emitError(loc,
"failed control flow structurization: nested "
2404 "loop continue block should be remapped!");
2407 Block *newMerge = it->second.mergeBlock;
2409 newMerge = mappedTo;
2413 blockMergeInfo.
erase(it);
2414 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2421 if (block->walk(updateMergeInfo).wasInterrupted())
2429 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2430 <<
" to only contain a spirv.Branch op\n");
2434 builder.setInsertionPointToEnd(block);
2435 spirv::BranchOp::create(builder, location, mergeBlock);
2437 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2442 LLVM_DEBUG(logger.startLine()
2443 <<
"[cf] after structurizing construct with header block "
2444 << headerBlock <<
":\n"
2450 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2453 <<
"//----- [phi] start wiring up block arguments -----//\n";
2459 for (
const auto &info : blockPhiInfo) {
2460 Block *block = info.first.first;
2461 Block *target = info.first.second;
2462 const BlockPhiInfo &phiInfo = info.second;
2464 logger.startLine() <<
"[phi] block " << block <<
"\n";
2465 logger.startLine() <<
"[phi] before creating block argument:\n";
2467 logger.startLine() <<
"\n";
2473 opBuilder.setInsertionPoint(op);
2476 blockArgs.reserve(phiInfo.size());
2477 for (uint32_t valueId : phiInfo) {
2478 if (
Value value = getValue(valueId)) {
2479 blockArgs.push_back(value);
2480 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2481 <<
" id = " << valueId <<
"\n");
2483 return emitError(unknownLoc,
"OpPhi references undefined value!");
2487 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2489 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2490 branchOp.getTarget(), blockArgs);
2492 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2493 assert((branchCondOp.getTrueBlock() == target ||
2494 branchCondOp.getFalseBlock() == target) &&
2495 "expected target to be either the true or false target");
2496 if (target == branchCondOp.getTrueTarget())
2497 spirv::BranchConditionalOp::create(
2498 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2499 blockArgs, branchCondOp.getFalseBlockArguments(),
2500 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2501 branchCondOp.getFalseTarget());
2503 spirv::BranchConditionalOp::create(
2504 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2505 branchCondOp.getTrueBlockArguments(), blockArgs,
2506 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2507 branchCondOp.getFalseBlock());
2509 branchCondOp.erase();
2511 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2515 logger.startLine() <<
"[phi] after creating block argument:\n";
2517 logger.startLine() <<
"\n";
2520 blockPhiInfo.clear();
2525 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2530 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2533 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2535 auto &[block, mergeInfo] = *it;
2538 if (mergeInfo.continueBlock)
2547 if (!isa<spirv::BranchConditionalOp>(terminator))
2551 bool splitHeaderMergeBlock =
false;
2552 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2553 if (mergeInfo.mergeBlock == block)
2554 splitHeaderMergeBlock =
true;
2561 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2564 spirv::BranchOp::create(builder, block->
getParent()->
getLoc(), newBlock);
2568 blockMergeInfo.erase(block);
2569 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2576 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2577 if (!
options.enableControlFlowStructurization) {
2581 <<
"//----- [cf] skip structurizing control flow -----//\n";
2589 <<
"//----- [cf] start structurizing control flow -----//\n";
2594 logger.startLine() <<
"[cf] split conditional blocks\n";
2595 logger.startLine() <<
"\n";
2598 if (failed(splitConditionalBlocks())) {
2605 while (!blockMergeInfo.empty()) {
2606 Block *headerBlock = blockMergeInfo.
begin()->first;
2607 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2610 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2611 headerBlock->
print(logger.getOStream());
2612 logger.startLine() <<
"\n";
2615 auto *mergeBlock = mergeInfo.mergeBlock;
2616 assert(mergeBlock &&
"merge block cannot be nullptr");
2617 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2618 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2620 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2621 mergeBlock->print(logger.getOStream());
2622 logger.startLine() <<
"\n";
2625 auto *continueBlock = mergeInfo.continueBlock;
2626 LLVM_DEBUG(
if (continueBlock) {
2627 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2628 continueBlock->print(logger.getOStream());
2629 logger.startLine() <<
"\n";
2633 blockMergeInfo.erase(blockMergeInfo.begin());
2634 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2635 blockMergeInfo, headerBlock,
2636 mergeBlock, continueBlock
2642 if (failed(structurizer.structurize()))
2649 <<
"//--- [cf] completed structurizing control flow ---//\n";
2662 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2663 if (fileName.empty())
2664 fileName =
"<unknown>";
2676 if (operands.size() != 3)
2677 return emitError(unknownLoc,
"OpLine must have 3 operands");
2678 debugLine = DebugLine{operands[0], operands[1], operands[2]};
2682 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2686 if (operands.size() < 2)
2687 return emitError(unknownLoc,
"OpString needs at least 2 operands");
2689 if (!debugInfoMap.lookup(operands[0]).empty())
2691 "duplicate debug string found for result <id> ")
2694 unsigned wordIndex = 1;
2696 if (wordIndex != operands.size())
2698 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool mightHaveTerminator()
Check whether this block might have a terminator.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
void push_back(Operation *op)
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< BlockOperand > getBlockOperands()
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Location getLoc()
Return a location for this region.
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.