23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
35 #define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
79 if (failed(sliceInstruction(opcode, operands)))
82 if (failed(processInstruction(opcode, operands)))
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second,
false))) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
103 return std::move(module);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
117 LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130 #define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
141 #undef MIN_VERSION_CASE
143 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
147 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158 if (operands.size() != 1)
159 return emitError(unknownLoc,
"OpCapability must have one parameter");
161 auto cap = spirv::symbolizeCapability(operands[0]);
163 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
165 capabilities.insert(*cap);
173 "OpExtension must have a literal string for the extension name");
176 unsigned wordIndex = 0;
178 if (wordIndex != words.size())
180 "unexpected trailing words in OpExtension instruction");
181 auto ext = spirv::symbolizeExtension(extName);
183 return emitError(unknownLoc,
"unknown extension: ") << extName;
185 extensions.insert(*ext);
191 if (words.size() < 2) {
193 "OpExtInstImport must have a result <id> and a literal "
194 "string for the extended instruction set name");
197 unsigned wordIndex = 1;
199 if (wordIndex != words.size()) {
201 "unexpected trailing words in OpExtInstImport");
206 void spirv::Deserializer::attachVCETriple() {
208 spirv::ModuleOp::getVCETripleAttrName(),
210 extensions.getArrayRef(), context));
215 if (operands.size() != 2)
216 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
219 module->getAddressingModelAttrName(),
220 opBuilder.getAttr<spirv::AddressingModelAttr>(
221 static_cast<spirv::AddressingModel
>(operands.front())));
223 (*module)->setAttr(module->getMemoryModelAttrName(),
224 opBuilder.getAttr<spirv::MemoryModelAttr>(
225 static_cast<spirv::MemoryModel
>(operands.back())));
230 template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
234 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
235 if (words.size() != 4) {
236 return emitError(loc,
"OpDecoration with ")
237 << decorationName <<
"needs a cache control integer literal and a "
238 << cacheControlKind <<
" cache control literal";
240 unsigned cacheLevel = words[2];
241 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
242 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
245 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
246 llvm::append_range(attrs, attrList);
247 attrs.push_back(value);
248 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
256 if (words.size() < 2) {
258 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
260 auto decorationName =
261 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
262 if (decorationName.empty()) {
263 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
265 auto symbol = getSymbolDecoration(decorationName);
266 switch (
static_cast<spirv::Decoration
>(words[1])) {
267 case spirv::Decoration::FPFastMathMode:
268 if (words.size() != 3) {
269 return emitError(unknownLoc,
"OpDecorate with ")
270 << decorationName <<
" needs a single integer literal";
272 decorations[words[0]].set(
274 static_cast<FPFastMathMode
>(words[2])));
276 case spirv::Decoration::FPRoundingMode:
277 if (words.size() != 3) {
278 return emitError(unknownLoc,
"OpDecorate with ")
279 << decorationName <<
" needs a single integer literal";
281 decorations[words[0]].set(
283 static_cast<FPRoundingMode
>(words[2])));
285 case spirv::Decoration::DescriptorSet:
286 case spirv::Decoration::Binding:
287 if (words.size() != 3) {
288 return emitError(unknownLoc,
"OpDecorate with ")
289 << decorationName <<
" needs a single integer literal";
291 decorations[words[0]].set(
292 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
294 case spirv::Decoration::BuiltIn:
295 if (words.size() != 3) {
296 return emitError(unknownLoc,
"OpDecorate with ")
297 << decorationName <<
" needs a single integer literal";
299 decorations[words[0]].set(
300 symbol, opBuilder.getStringAttr(
301 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
303 case spirv::Decoration::ArrayStride:
304 if (words.size() != 3) {
305 return emitError(unknownLoc,
"OpDecorate with ")
306 << decorationName <<
" needs a single integer literal";
308 typeDecorations[words[0]] = words[2];
310 case spirv::Decoration::LinkageAttributes: {
311 if (words.size() < 4) {
312 return emitError(unknownLoc,
"OpDecorate with ")
314 <<
" needs at least 1 string and 1 integer literal";
322 unsigned wordIndex = 2;
324 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
325 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
326 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
331 case spirv::Decoration::Aliased:
332 case spirv::Decoration::AliasedPointer:
333 case spirv::Decoration::Block:
334 case spirv::Decoration::BufferBlock:
335 case spirv::Decoration::Flat:
336 case spirv::Decoration::NonReadable:
337 case spirv::Decoration::NonWritable:
338 case spirv::Decoration::NoPerspective:
339 case spirv::Decoration::NoSignedWrap:
340 case spirv::Decoration::NoUnsignedWrap:
341 case spirv::Decoration::RelaxedPrecision:
342 case spirv::Decoration::Restrict:
343 case spirv::Decoration::RestrictPointer:
344 case spirv::Decoration::NoContraction:
345 case spirv::Decoration::Constant:
346 if (words.size() != 2) {
347 return emitError(unknownLoc,
"OpDecoration with ")
348 << decorationName <<
"needs a single target <id>";
354 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
356 case spirv::Decoration::Location:
357 case spirv::Decoration::SpecId:
358 if (words.size() != 3) {
359 return emitError(unknownLoc,
"OpDecoration with ")
360 << decorationName <<
"needs a single integer literal";
362 decorations[words[0]].set(
363 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
365 case spirv::Decoration::CacheControlLoadINTEL: {
367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
374 case spirv::Decoration::CacheControlStoreINTEL: {
376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
384 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
392 if (words.size() < 3) {
394 "OpMemberDecorate must have at least 3 operands");
397 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
398 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
400 " missing offset specification in OpMemberDecorate with "
401 "Offset decoration");
404 if (words.size() > 3) {
405 decorationOperands = words.slice(3);
407 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
412 if (words.size() < 3) {
413 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
415 unsigned wordIndex = 2;
417 if (wordIndex != words.size()) {
419 "unexpected trailing words in OpMemberName instruction");
421 memberNameMap[words[0]][words[1]] = name;
425 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
427 if (!decorations.contains(argID)) {
432 spirv::DecorationAttr foundDecorationAttr;
434 for (
auto decoration :
435 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436 spirv::Decoration::AliasedPointer,
437 spirv::Decoration::RestrictPointer}) {
439 if (decAttr.getName() !=
440 getSymbolDecoration(stringifyDecoration(decoration)))
443 if (foundDecorationAttr)
445 "more than one Aliased/Restrict decorations for "
446 "function argument with result <id> ")
453 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
454 spirv::Decoration::RelaxedPrecision))) {
459 if (foundDecorationAttr)
460 return emitError(unknownLoc,
"already found a decoration for function "
461 "argument with result <id> ")
465 context, spirv::Decoration::RelaxedPrecision);
469 if (!foundDecorationAttr)
470 return emitError(unknownLoc,
"unimplemented decoration support for "
471 "function argument with result <id> ")
475 foundDecorationAttr);
483 return emitError(unknownLoc,
"found function inside function");
487 if (operands.size() != 4) {
488 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
492 return emitError(unknownLoc,
"undefined result type from <id> ")
496 uint32_t fnID = operands[1];
497 if (funcMap.count(fnID)) {
498 return emitError(unknownLoc,
"duplicate function definition/declaration");
501 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
503 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
507 if (!fnType || !isa<FunctionType>(fnType)) {
508 return emitError(unknownLoc,
"unknown function type from <id> ")
511 auto functionType = cast<FunctionType>(fnType);
513 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
514 (functionType.getNumResults() == 1 &&
515 functionType.getResult(0) != resultType)) {
516 return emitError(unknownLoc,
"mismatch in function type ")
517 << functionType <<
" and return type " << resultType <<
" specified";
520 std::string fnName = getFunctionSymbol(fnID);
521 auto funcOp = opBuilder.create<spirv::FuncOp>(
522 unknownLoc, fnName, functionType, fnControl.value());
524 if (decorations.count(fnID)) {
525 for (
auto attr : decorations[fnID].getAttrs()) {
526 funcOp->setAttr(attr.getName(), attr.getValue());
529 curFunction = funcMap[fnID] = funcOp;
530 auto *entryBlock = funcOp.addEntryBlock();
533 <<
"//===-------------------------------------------===//\n";
534 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
535 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
536 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
537 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
542 argAttrs.resize(functionType.getNumInputs());
545 if (functionType.getNumInputs()) {
546 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547 auto argType = functionType.getInput(i);
548 spirv::Opcode opcode = spirv::Opcode::OpNop;
550 if (failed(sliceInstruction(opcode, operands,
551 spirv::Opcode::OpFunctionParameter))) {
554 if (opcode != spirv::Opcode::OpFunctionParameter) {
557 "missing OpFunctionParameter instruction for argument ")
560 if (operands.size() != 2) {
563 "expected result type and result <id> for OpFunctionParameter");
565 auto argDefinedType =
getType(operands[0]);
566 if (!argDefinedType || argDefinedType != argType) {
568 "mismatch in argument type between function type "
570 << functionType <<
" and argument type definition "
571 << argDefinedType <<
" at argument " << i;
573 if (getValue(operands[1])) {
574 return emitError(unknownLoc,
"duplicate definition of result <id> ")
577 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
581 auto argValue = funcOp.getArgument(i);
582 valueMap[operands[1]] = argValue;
586 if (llvm::any_of(argAttrs, [](
Attribute attr) {
587 auto argAttr = cast<DictionaryAttr>(attr);
588 return !argAttr.empty();
595 auto linkageAttr = funcOp.getLinkageAttributes();
596 auto hasImportLinkage =
597 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
598 spirv::LinkageType::Import);
599 if (hasImportLinkage)
606 spirv::Opcode opcode = spirv::Opcode::OpNop;
614 if (failed(sliceInstruction(opcode, instOperands,
615 spirv::Opcode::OpFunctionEnd))) {
618 if (opcode == spirv::Opcode::OpFunctionEnd) {
619 return processFunctionEnd(instOperands);
621 if (opcode != spirv::Opcode::OpLabel) {
622 return emitError(unknownLoc,
"a basic block must start with OpLabel");
624 if (instOperands.size() != 1) {
625 return emitError(unknownLoc,
"OpLabel should only have result <id>");
627 blockMap[instOperands[0]] = entryBlock;
628 if (failed(processLabel(instOperands))) {
634 while (succeeded(sliceInstruction(opcode, instOperands,
635 spirv::Opcode::OpFunctionEnd)) &&
636 opcode != spirv::Opcode::OpFunctionEnd) {
637 if (failed(processInstruction(opcode, instOperands))) {
641 if (opcode != spirv::Opcode::OpFunctionEnd) {
645 return processFunctionEnd(instOperands);
651 if (!operands.empty()) {
652 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
658 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
663 curFunction = std::nullopt;
668 <<
"//===-------------------------------------------===//\n";
673 std::optional<std::pair<Attribute, Type>>
674 spirv::Deserializer::getConstant(uint32_t
id) {
675 auto constIt = constantMap.find(
id);
676 if (constIt == constantMap.end())
678 return constIt->getSecond();
681 std::optional<spirv::SpecConstOperationMaterializationInfo>
682 spirv::Deserializer::getSpecConstantOperation(uint32_t
id) {
683 auto constIt = specConstOperationMap.find(
id);
684 if (constIt == specConstOperationMap.end())
686 return constIt->getSecond();
689 std::string spirv::Deserializer::getFunctionSymbol(uint32_t
id) {
690 auto funcName = nameMap.lookup(
id).str();
691 if (funcName.empty()) {
692 funcName =
"spirv_fn_" + std::to_string(
id);
697 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t
id) {
698 auto constName = nameMap.lookup(
id).str();
699 if (constName.empty()) {
700 constName =
"spirv_spec_const_" + std::to_string(
id);
705 spirv::SpecConstantOp
706 spirv::Deserializer::createSpecConstant(
Location loc, uint32_t resultID,
707 TypedAttr defaultValue) {
708 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
709 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
711 if (decorations.count(resultID)) {
712 for (
auto attr : decorations[resultID].getAttrs())
713 op->setAttr(attr.getName(), attr.getValue());
715 specConstMap[resultID] = op;
721 unsigned wordIndex = 0;
722 if (operands.size() < 3) {
725 "OpVariable needs at least 3 operands, type, <id> and storage class");
729 auto type =
getType(operands[wordIndex]);
731 return emitError(unknownLoc,
"unknown result type <id> : ")
732 << operands[wordIndex];
734 auto ptrType = dyn_cast<spirv::PointerType>(type);
737 "expected a result type <id> to be a spirv.ptr, found : ")
743 auto variableID = operands[wordIndex];
744 auto variableName = nameMap.lookup(variableID).str();
745 if (variableName.empty()) {
746 variableName =
"spirv_var_" + std::to_string(variableID);
751 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
752 if (ptrType.getStorageClass() != storageClass) {
753 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
754 << type <<
" and that specified in OpVariable instruction : "
755 << stringifyStorageClass(storageClass);
762 if (wordIndex < operands.size()) {
765 if (
auto initOp = getGlobalVariable(operands[wordIndex]))
767 else if (
auto initOp = getSpecConstant(operands[wordIndex]))
769 else if (
auto initOp = getSpecConstantComposite(operands[wordIndex]))
772 return emitError(unknownLoc,
"unknown <id> ")
773 << operands[wordIndex] <<
"used as initializer";
778 if (wordIndex != operands.size()) {
780 "found more operands than expected when deserializing "
781 "OpVariable instruction, only ")
782 << wordIndex <<
" of " << operands.size() <<
" processed";
784 auto loc = createFileLineColLoc(opBuilder);
785 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
786 loc,
TypeAttr::get(type), opBuilder.getStringAttr(variableName),
790 if (decorations.count(variableID)) {
791 for (
auto attr : decorations[variableID].getAttrs())
792 varOp->setAttr(attr.getName(), attr.getValue());
794 globalVariableMap[variableID] = varOp;
798 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t
id) {
799 auto constInfo = getConstant(
id);
803 return dyn_cast<IntegerAttr>(constInfo->first);
807 if (operands.size() < 2) {
808 return emitError(unknownLoc,
"OpName needs at least 2 operands");
810 if (!nameMap.lookup(operands[0]).empty()) {
811 return emitError(unknownLoc,
"duplicate name found for result <id> ")
814 unsigned wordIndex = 1;
816 if (wordIndex != operands.size()) {
818 "unexpected trailing words in OpName instruction");
820 nameMap[operands[0]] = name;
828 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
830 if (operands.empty()) {
831 return emitError(unknownLoc,
"type instruction with opcode ")
832 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
837 if (typeMap.count(operands[0])) {
838 return emitError(unknownLoc,
"duplicate definition for result <id> ")
843 case spirv::Opcode::OpTypeVoid:
844 if (operands.size() != 1)
845 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
846 typeMap[operands[0]] = opBuilder.getNoneType();
848 case spirv::Opcode::OpTypeBool:
849 if (operands.size() != 1)
850 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
851 typeMap[operands[0]] = opBuilder.getI1Type();
853 case spirv::Opcode::OpTypeInt: {
854 if (operands.size() != 3)
856 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
866 : IntegerType::SignednessSemantics::Signless;
869 case spirv::Opcode::OpTypeFloat: {
870 if (operands.size() != 2 && operands.size() != 3)
872 "OpTypeFloat expects either 2 operands (type, bitwidth) "
873 "or 3 operands (type, bitwidth, encoding), but got ")
875 uint32_t bitWidth = operands[1];
880 floatTy = opBuilder.getF16Type();
883 floatTy = opBuilder.getF32Type();
886 floatTy = opBuilder.getF64Type();
889 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
893 if (operands.size() == 3) {
894 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
895 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
899 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
900 << bitWidth <<
" (expected 16)";
901 floatTy = opBuilder.getBF16Type();
904 typeMap[operands[0]] = floatTy;
906 case spirv::Opcode::OpTypeVector: {
907 if (operands.size() != 3) {
910 "OpTypeVector must have element type and count parameters");
914 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
919 case spirv::Opcode::OpTypePointer: {
920 return processOpTypePointer(operands);
922 case spirv::Opcode::OpTypeArray:
923 return processArrayType(operands);
924 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
925 return processCooperativeMatrixTypeKHR(operands);
926 case spirv::Opcode::OpTypeFunction:
927 return processFunctionType(operands);
928 case spirv::Opcode::OpTypeImage:
929 return processImageType(operands);
930 case spirv::Opcode::OpTypeSampledImage:
931 return processSampledImageType(operands);
932 case spirv::Opcode::OpTypeRuntimeArray:
933 return processRuntimeArrayType(operands);
934 case spirv::Opcode::OpTypeStruct:
935 return processStructType(operands);
936 case spirv::Opcode::OpTypeMatrix:
937 return processMatrixType(operands);
938 case spirv::Opcode::OpTypeTensorARM:
939 return processTensorARMType(operands);
941 return emitError(unknownLoc,
"unhandled type instruction");
948 if (operands.size() != 3)
949 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
951 auto pointeeType =
getType(operands[2]);
953 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
956 uint32_t typePointerID = operands[0];
957 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
960 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
961 deferredStructIt != std::end(deferredStructTypesInfos);) {
962 for (
auto *unresolvedMemberIt =
963 std::begin(deferredStructIt->unresolvedMemberTypes);
964 unresolvedMemberIt !=
965 std::end(deferredStructIt->unresolvedMemberTypes);) {
966 if (unresolvedMemberIt->first == typePointerID) {
970 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
971 typeMap[typePointerID];
973 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
975 ++unresolvedMemberIt;
979 if (deferredStructIt->unresolvedMemberTypes.empty()) {
981 auto structType = deferredStructIt->deferredStructType;
983 assert(structType &&
"expected a spirv::StructType");
984 assert(structType.isIdentified() &&
"expected an indentified struct");
986 if (failed(structType.trySetBody(
987 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
988 deferredStructIt->memberDecorationsInfo)))
991 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1002 if (operands.size() != 3) {
1004 "OpTypeArray must have element type and count parameters");
1009 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1015 auto countInfo = getConstant(operands[2]);
1017 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1018 << operands[2] <<
"can only come from normal constant right now";
1021 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1022 count = intVal.getValue().getZExtValue();
1024 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1025 "scalar integer constant instruction");
1029 elementTy, count, typeDecorations.lookup(operands[0]));
1035 assert(!operands.empty() &&
"No operands for processing function type");
1036 if (operands.size() == 1) {
1037 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1039 auto returnType =
getType(operands[1]);
1041 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1044 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1045 auto ty =
getType(operands[i]);
1047 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1049 argTypes.push_back(ty);
1052 if (!isVoidType(returnType)) {
1059 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1061 if (operands.size() != 6) {
1063 "OpTypeCooperativeMatrixKHR must have element type, "
1064 "scope, row and column parameters, and use");
1070 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1074 std::optional<spirv::Scope> scope =
1075 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1079 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1083 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1084 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1085 IntegerAttr useAttr = getConstantInt(operands[5]);
1088 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1089 "undefined constant <id> ")
1093 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1094 "references undefined constant <id> ")
1098 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1099 "undefined constant <id> ")
1102 unsigned rows = rowsAttr.getInt();
1103 unsigned columns = columnsAttr.getInt();
1105 std::optional<spirv::CooperativeMatrixUseKHR> use =
1106 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1110 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1114 typeMap[operands[0]] =
1121 if (operands.size() != 2) {
1122 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1127 "OpTypeRuntimeArray references undefined <id> ")
1131 memberType, typeDecorations.lookup(operands[0]));
1139 if (operands.empty()) {
1140 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1143 if (operands.size() == 1) {
1145 typeMap[operands[0]] =
1154 for (
auto op : llvm::drop_begin(operands, 1)) {
1156 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1158 if (!memberType && !typeForwardPtr)
1159 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1163 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1165 memberTypes.push_back(memberType);
1170 if (memberDecorationMap.count(operands[0])) {
1171 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1172 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1173 if (allMemberDecorations.count(memberIndex)) {
1174 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1176 if (memberDecoration.first == spirv::Decoration::Offset) {
1178 if (offsetInfo.empty()) {
1179 offsetInfo.resize(memberTypes.size());
1181 offsetInfo[memberIndex] = memberDecoration.second[0];
1183 if (!memberDecoration.second.empty()) {
1184 memberDecorationsInfo.emplace_back(memberIndex, 1,
1185 memberDecoration.first,
1186 memberDecoration.second[0]);
1188 memberDecorationsInfo.emplace_back(memberIndex, 0,
1189 memberDecoration.first, 0);
1197 uint32_t structID = operands[0];
1198 std::string structIdentifier = nameMap.lookup(structID).str();
1200 if (structIdentifier.empty()) {
1201 assert(unresolvedMemberTypes.empty() &&
1202 "didn't expect unresolved member types");
1207 typeMap[structID] = structTy;
1209 if (!unresolvedMemberTypes.empty())
1210 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1211 memberTypes, offsetInfo,
1212 memberDecorationsInfo});
1213 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1214 memberDecorationsInfo)))
1225 if (operands.size() != 3) {
1227 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1228 " (result_id, column_type, and column_count)");
1234 "OpTypeMatrix references undefined column type.")
1238 uint32_t colsCount = operands[2];
1245 unsigned size = operands.size();
1246 if (size < 2 || size > 4)
1247 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1248 "(result_id, element_type, (rank), (shape)) ")
1254 "OpTypeTensorARM references undefined element type ")
1262 IntegerAttr rankAttr = getConstantInt(operands[2]);
1264 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1265 "scalar integer constant instruction");
1266 unsigned rank = rankAttr.getValue().getZExtValue();
1273 std::optional<std::pair<Attribute, Type>> shapeInfo =
1274 getConstant(operands[3]);
1276 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1277 "constant instruction of type OpTypeArray");
1279 ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
1281 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1282 auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
1284 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1286 shape.push_back(dimIntAttr.getValue().getSExtValue());
1294 if (operands.size() != 2)
1296 "OpTypeForwardPointer instruction must have two operands");
1298 typeForwardPointerIDs.insert(operands[0]);
1308 if (operands.size() != 8)
1311 "OpTypeImage with non-eight operands are not supported yet");
1315 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1318 auto dim = spirv::symbolizeDim(operands[2]);
1320 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1323 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1325 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1328 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1330 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1333 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1335 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1337 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1338 if (!samplerUseInfo)
1339 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1342 auto format = spirv::symbolizeImageFormat(operands[7]);
1344 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1348 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1349 samplingInfo.value(), samplerUseInfo.value(), format.value());
1355 if (operands.size() != 2)
1356 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1361 "OpTypeSampledImage references undefined <id>: ")
1374 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1376 if (operands.size() < 2) {
1378 << opname <<
" must have type <id> and result <id>";
1380 if (operands.size() < 3) {
1382 << opname <<
" must have at least 1 more parameter";
1387 return emitError(unknownLoc,
"undefined result type from <id> ")
1391 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1392 if (bitwidth == 64) {
1393 if (operands.size() == 4) {
1397 << opname <<
" should have 2 parameters for 64-bit values";
1399 if (bitwidth <= 32) {
1400 if (operands.size() == 3) {
1406 <<
" should have 1 parameter for values with no more than 32 bits";
1408 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1412 auto resultID = operands[1];
1414 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1415 auto bitwidth = intType.getWidth();
1416 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1421 if (bitwidth == 64) {
1428 } words = {operands[2], operands[3]};
1429 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1430 }
else if (bitwidth <= 32) {
1431 value = APInt(bitwidth, operands[2],
true,
1435 auto attr = opBuilder.getIntegerAttr(intType, value);
1438 createSpecConstant(unknownLoc, resultID, attr);
1442 constantMap.try_emplace(resultID, attr, intType);
1448 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1449 auto bitwidth = floatType.getWidth();
1450 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1455 if (floatType.isF64()) {
1462 } words = {operands[2], operands[3]};
1463 value = APFloat(llvm::bit_cast<double>(words));
1464 }
else if (floatType.isF32()) {
1465 value = APFloat(llvm::bit_cast<float>(operands[2]));
1466 }
else if (floatType.isF16()) {
1467 APInt data(16, operands[2]);
1468 value = APFloat(APFloat::IEEEhalf(), data);
1469 }
else if (floatType.isBF16()) {
1470 APInt data(16, operands[2]);
1471 value = APFloat(APFloat::BFloat(), data);
1474 auto attr = opBuilder.getFloatAttr(floatType, value);
1476 createSpecConstant(unknownLoc, resultID, attr);
1480 constantMap.try_emplace(resultID, attr, floatType);
1486 return emitError(unknownLoc,
"OpConstant can only generate values of "
1487 "scalar integer or floating-point type");
1490 LogicalResult spirv::Deserializer::processConstantBool(
1492 if (operands.size() != 2) {
1494 << (isSpec ?
"Spec" :
"") <<
"Constant"
1495 << (isTrue ?
"True" :
"False")
1496 <<
" must have type <id> and result <id>";
1499 auto attr = opBuilder.getBoolAttr(isTrue);
1500 auto resultID = operands[1];
1502 createSpecConstant(unknownLoc, resultID, attr);
1506 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1514 if (operands.size() < 2) {
1516 "OpConstantComposite must have type <id> and result <id>");
1518 if (operands.size() < 3) {
1520 "OpConstantComposite must have at least 1 parameter");
1525 return emitError(unknownLoc,
"undefined result type from <id> ")
1530 elements.reserve(operands.size() - 2);
1531 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1532 auto elementInfo = getConstant(operands[i]);
1534 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1535 << operands[i] <<
" must come from a normal constant";
1537 elements.push_back(elementInfo->first);
1540 auto resultID = operands[1];
1541 if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1545 constantMap.try_emplace(resultID, attr, shapedType);
1546 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1547 auto attr = opBuilder.getArrayAttr(elements);
1548 constantMap.try_emplace(resultID, attr, resultType);
1550 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1559 if (operands.size() < 2) {
1561 "OpConstantComposite must have type <id> and result <id>");
1563 if (operands.size() < 3) {
1565 "OpConstantComposite must have at least 1 parameter");
1570 return emitError(unknownLoc,
"undefined result type from <id> ")
1574 auto resultID = operands[1];
1575 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1578 elements.reserve(operands.size() - 2);
1579 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1580 auto elementInfo = getSpecConstant(operands[i]);
1584 auto op = opBuilder.
create<spirv::SpecConstantCompositeOp>(
1586 opBuilder.getArrayAttr(elements));
1587 specConstCompositeMap[resultID] = op;
1594 if (operands.size() < 3)
1595 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1596 "result <id>, and operand opcode");
1598 uint32_t resultTypeID = operands[0];
1601 return emitError(unknownLoc,
"undefined result type from <id> ")
1604 uint32_t resultID = operands[1];
1605 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1606 auto emplaceResult = specConstOperationMap.try_emplace(
1608 SpecConstOperationMaterializationInfo{
1609 enclosedOpcode, resultTypeID,
1612 if (!emplaceResult.second)
1613 return emitError(unknownLoc,
"value with <id>: ")
1614 << resultID <<
" is probably defined before.";
1619 Value spirv::Deserializer::materializeSpecConstantOperation(
1620 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1636 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1637 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
1640 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1641 enclosedOpResultTypeAndOperands.push_back(fakeID);
1642 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1643 enclosedOpOperands.end());
1650 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1657 auto loc = createFileLineColLoc(opBuilder);
1658 auto specConstOperationOp =
1659 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1661 Region &body = specConstOperationOp.getBody();
1663 body.
getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1670 opBuilder.setInsertionPointToEnd(&block);
1672 opBuilder.create<spirv::YieldOp>(loc, block.
front().
getResult(0));
1673 return specConstOperationOp.getResult();
1678 if (operands.size() != 2) {
1680 "OpConstantNull must have type <id> and result <id>");
1685 return emitError(unknownLoc,
"undefined result type from <id> ")
1689 auto resultID = operands[1];
1690 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
1691 auto attr = opBuilder.getZeroAttr(resultType);
1694 constantMap.try_emplace(resultID, attr, resultType);
1698 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
1706 Block *spirv::Deserializer::getOrCreateBlock(uint32_t
id) {
1707 if (
auto *block = getBlock(
id)) {
1708 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
1709 <<
" @ " << block <<
"\n");
1716 auto *block = curFunction->addBlock();
1717 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
1718 <<
" @ " << block <<
"\n");
1719 return blockMap[id] = block;
1724 return emitError(unknownLoc,
"OpBranch must appear inside a block");
1727 if (operands.size() != 1) {
1728 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
1731 auto *target = getOrCreateBlock(operands[0]);
1732 auto loc = createFileLineColLoc(opBuilder);
1736 opBuilder.create<spirv::BranchOp>(loc, target);
1746 "OpBranchConditional must appear inside a block");
1749 if (operands.size() != 3 && operands.size() != 5) {
1751 "OpBranchConditional must have condition, true label, "
1752 "false label, and optionally two branch weights");
1755 auto condition = getValue(operands[0]);
1756 auto *trueBlock = getOrCreateBlock(operands[1]);
1757 auto *falseBlock = getOrCreateBlock(operands[2]);
1759 std::optional<std::pair<uint32_t, uint32_t>> weights;
1760 if (operands.size() == 5) {
1761 weights = std::make_pair(operands[3], operands[4]);
1766 auto loc = createFileLineColLoc(opBuilder);
1767 opBuilder.create<spirv::BranchConditionalOp>(
1768 loc, condition, trueBlock,
1778 return emitError(unknownLoc,
"OpLabel must appear inside a function");
1781 if (operands.size() != 1) {
1782 return emitError(unknownLoc,
"OpLabel should only have result <id>");
1785 auto labelID = operands[0];
1787 auto *block = getOrCreateBlock(labelID);
1788 LLVM_DEBUG(logger.startLine()
1789 <<
"[block] populating block " << block <<
"\n");
1791 assert(block->
empty() &&
"re-deserialize the same block!");
1793 opBuilder.setInsertionPointToStart(block);
1794 blockMap[labelID] = curBlock = block;
1802 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
1805 if (operands.size() < 2) {
1808 "OpSelectionMerge must specify merge target and selection control");
1811 auto *mergeBlock = getOrCreateBlock(operands[0]);
1812 auto loc = createFileLineColLoc(opBuilder);
1813 auto selectionControl = operands[1];
1815 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1819 "a block cannot have more than one OpSelectionMerge instruction");
1828 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
1831 if (operands.size() < 3) {
1832 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
1833 "continue target and loop control");
1836 auto *mergeBlock = getOrCreateBlock(operands[0]);
1837 auto *continueBlock = getOrCreateBlock(operands[1]);
1838 auto loc = createFileLineColLoc(opBuilder);
1839 uint32_t loopControl = operands[2];
1842 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1846 "a block cannot have more than one OpLoopMerge instruction");
1854 return emitError(unknownLoc,
"OpPhi must appear in a block");
1857 if (operands.size() < 4) {
1858 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
1859 "and variable-parent pairs");
1864 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1865 valueMap[operands[1]] = blockArg;
1866 LLVM_DEBUG(logger.startLine()
1867 <<
"[phi] created block argument " << blockArg
1868 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
1872 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
1873 uint32_t value = operands[i];
1874 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1875 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1876 blockPhiInfo[predecessorTargetPair].
push_back(value);
1877 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
1878 <<
" with arg id = " << value <<
"\n");
1887 class ControlFlowStructurizer {
1890 ControlFlowStructurizer(
Location loc, uint32_t control,
1893 llvm::ScopedPrinter &logger)
1894 : location(loc), control(control), blockMergeInfo(mergeInfo),
1895 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1898 ControlFlowStructurizer(
Location loc, uint32_t control,
1901 : location(loc), control(control), blockMergeInfo(mergeInfo),
1902 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1912 LogicalResult structurize();
1917 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1920 spirv::LoopOp createLoopOp(uint32_t loopControl);
1923 void collectBlocksInConstruct();
1932 Block *continueBlock;
1938 llvm::ScopedPrinter &logger;
1944 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1947 OpBuilder builder(&mergeBlock->front());
1949 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
1950 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1951 selectionOp.addMergeBlock(builder);
1956 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1959 OpBuilder builder(&mergeBlock->front());
1961 auto control =
static_cast<spirv::LoopControl
>(loopControl);
1962 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1963 loopOp.addEntryAndMergeBlock(builder);
1968 void ControlFlowStructurizer::collectBlocksInConstruct() {
1969 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
1972 constructBlocks.insert(headerBlock);
1976 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
1977 for (
auto *successor : constructBlocks[i]->getSuccessors())
1978 if (successor != mergeBlock)
1979 constructBlocks.insert(successor);
1983 LogicalResult ControlFlowStructurizer::structurize() {
1985 bool isLoop = continueBlock !=
nullptr;
1987 if (
auto loopOp = createLoopOp(control))
1988 op = loopOp.getOperation();
1990 if (
auto selectionOp = createSelectionOp(control))
1991 op = selectionOp.getOperation();
2000 mapper.
map(mergeBlock, &body.
back());
2002 collectBlocksInConstruct();
2024 for (
auto *block : constructBlocks) {
2027 auto *newBlock = builder.createBlock(&body.
back());
2028 mapper.
map(block, newBlock);
2029 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2030 <<
" from block " << block <<
"\n");
2034 newBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
2035 mapper.
map(blockArg, newArg);
2036 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2037 << blockArg <<
" to " << newArg <<
"\n");
2040 LLVM_DEBUG(logger.startLine()
2041 <<
"[cf] block " << block <<
" is a function entry block\n");
2044 for (
auto &op : *block)
2045 newBlock->push_back(op.
clone(mapper));
2049 auto remapOperands = [&](
Operation *op) {
2052 operand.set(mappedOp);
2055 succOp.set(mappedOp);
2057 for (
auto &block : body)
2058 block.walk(remapOperands);
2066 headerBlock->replaceAllUsesWith(mergeBlock);
2069 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2070 headerBlock->getParentOp()->
print(logger.getOStream());
2071 logger.startLine() <<
"\n";
2075 if (!mergeBlock->args_empty()) {
2076 return mergeBlock->getParentOp()->emitError(
2077 "OpPhi in loop merge block unsupported");
2084 mergeBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
2089 if (!headerBlock->args_empty())
2090 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2094 builder.setInsertionPointToEnd(&body.front());
2095 builder.create<spirv::BranchOp>(location, mapper.
lookupOrNull(headerBlock),
2123 body.back().addArgument(blockArg.
getType(), blockArg.
getLoc());
2124 valuesToYield.push_back(body.back().getArguments().back());
2125 outsideUses.push_back(blockArg);
2130 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2133 for (
auto *block : constructBlocks)
2134 block->dropAllReferences();
2139 for (
Block *block : constructBlocks) {
2144 outsideUses.push_back(result);
2148 if (!arg.use_empty()) {
2150 outsideUses.push_back(arg);
2155 assert(valuesToYield.size() == outsideUses.size());
2159 if (!valuesToYield.empty()) {
2160 LLVM_DEBUG(logger.startLine()
2161 <<
"[cf] yielding values from the selection / loop region\n");
2164 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2165 Operation *merge = llvm::getSingleElement(mergeOps);
2167 merge->setOperands(valuesToYield);
2175 builder.setInsertionPoint(&mergeBlock->front());
2180 newOp = builder.
create<spirv::LoopOp>(
2182 static_cast<spirv::LoopControl
>(control));
2184 newOp = builder.
create<spirv::SelectionOp>(
2186 static_cast<spirv::SelectionControl
>(control));
2196 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2197 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2203 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2210 for (
auto *block : constructBlocks) {
2213 return op.
emitOpError(
"failed control flow structurization: value has "
2214 "uses outside of the "
2215 "enclosing selection/loop construct");
2217 if (!arg.use_empty())
2218 return emitError(arg.getLoc(),
"failed control flow structurization: "
2219 "block argument has uses outside of the "
2220 "enclosing selection/loop construct");
2224 for (
auto *block : constructBlocks) {
2265 auto it = blockMergeInfo.find(block);
2266 if (it != blockMergeInfo.end()) {
2272 return emitError(loc,
"failed control flow structurization: nested "
2273 "loop header block should be remapped!");
2275 Block *newContinue = it->second.continueBlock;
2279 return emitError(loc,
"failed control flow structurization: nested "
2280 "loop continue block should be remapped!");
2283 Block *newMerge = it->second.mergeBlock;
2285 newMerge = mappedTo;
2289 blockMergeInfo.
erase(it);
2290 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2297 if (block->walk(updateMergeInfo).wasInterrupted())
2305 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2306 <<
" to only contain a spirv.Branch op\n");
2310 builder.setInsertionPointToEnd(block);
2311 builder.create<spirv::BranchOp>(location, mergeBlock);
2313 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2318 LLVM_DEBUG(logger.startLine()
2319 <<
"[cf] after structurizing construct with header block "
2320 << headerBlock <<
":\n"
2326 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2329 <<
"//----- [phi] start wiring up block arguments -----//\n";
2335 for (
const auto &info : blockPhiInfo) {
2336 Block *block = info.first.first;
2337 Block *target = info.first.second;
2338 const BlockPhiInfo &phiInfo = info.second;
2340 logger.startLine() <<
"[phi] block " << block <<
"\n";
2341 logger.startLine() <<
"[phi] before creating block argument:\n";
2343 logger.startLine() <<
"\n";
2349 opBuilder.setInsertionPoint(op);
2352 blockArgs.reserve(phiInfo.size());
2353 for (uint32_t valueId : phiInfo) {
2354 if (
Value value = getValue(valueId)) {
2355 blockArgs.push_back(value);
2356 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2357 <<
" id = " << valueId <<
"\n");
2359 return emitError(unknownLoc,
"OpPhi references undefined value!");
2363 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2365 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2368 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2369 assert((branchCondOp.getTrueBlock() == target ||
2370 branchCondOp.getFalseBlock() == target) &&
2371 "expected target to be either the true or false target");
2372 if (target == branchCondOp.getTrueTarget())
2373 opBuilder.create<spirv::BranchConditionalOp>(
2374 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2375 branchCondOp.getFalseBlockArguments(),
2376 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2377 branchCondOp.getFalseTarget());
2379 opBuilder.create<spirv::BranchConditionalOp>(
2380 branchCondOp.getLoc(), branchCondOp.getCondition(),
2381 branchCondOp.getTrueBlockArguments(), blockArgs,
2382 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2383 branchCondOp.getFalseBlock());
2385 branchCondOp.erase();
2387 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2391 logger.startLine() <<
"[phi] after creating block argument:\n";
2393 logger.startLine() <<
"\n";
2396 blockPhiInfo.clear();
2401 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2406 LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2409 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2411 auto &[block, mergeInfo] = *it;
2414 if (mergeInfo.continueBlock)
2423 if (!isa<spirv::BranchConditionalOp>(terminator))
2427 bool splitHeaderMergeBlock =
false;
2428 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2429 if (mergeInfo.mergeBlock == block)
2430 splitHeaderMergeBlock =
true;
2437 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2440 builder.create<spirv::BranchOp>(block->
getParent()->
getLoc(), newBlock);
2444 blockMergeInfo.erase(block);
2445 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2452 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2453 if (!
options.enableControlFlowStructurization) {
2457 <<
"//----- [cf] skip structurizing control flow -----//\n";
2465 <<
"//----- [cf] start structurizing control flow -----//\n";
2470 logger.startLine() <<
"[cf] split conditional blocks\n";
2471 logger.startLine() <<
"\n";
2474 if (failed(splitConditionalBlocks())) {
2481 while (!blockMergeInfo.empty()) {
2482 Block *headerBlock = blockMergeInfo.
begin()->first;
2483 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2486 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2487 headerBlock->
print(logger.getOStream());
2488 logger.startLine() <<
"\n";
2491 auto *mergeBlock = mergeInfo.mergeBlock;
2492 assert(mergeBlock &&
"merge block cannot be nullptr");
2493 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2494 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2496 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2497 mergeBlock->print(logger.getOStream());
2498 logger.startLine() <<
"\n";
2501 auto *continueBlock = mergeInfo.continueBlock;
2502 LLVM_DEBUG(
if (continueBlock) {
2503 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2504 continueBlock->print(logger.getOStream());
2505 logger.startLine() <<
"\n";
2509 blockMergeInfo.erase(blockMergeInfo.begin());
2510 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2511 blockMergeInfo, headerBlock,
2512 mergeBlock, continueBlock
2518 if (failed(structurizer.structurize()))
2525 <<
"//--- [cf] completed structurizing control flow ---//\n";
2538 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2539 if (fileName.empty())
2540 fileName =
"<unknown>";
2552 if (operands.size() != 3)
2553 return emitError(unknownLoc,
"OpLine must have 3 operands");
2554 debugLine = DebugLine{operands[0], operands[1], operands[2]};
/// Empties `debugLine`, dropping the file/line/column recorded by the most
/// recent OpLine instruction (see processDebugLine).
2558 void spirv::Deserializer::clearDebugLine() { debugLine.reset(); }
2562 if (operands.size() < 2)
2563 return emitError(unknownLoc,
"OpString needs at least 2 operands");
2565 if (!debugInfoMap.lookup(operands[0]).empty())
2567 "duplicate debug string found for result <id> ")
2570 unsigned wordIndex = 1;
2572 if (wordIndex != operands.size())
2574 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool mightHaveTerminator()
Check whether this block might have a terminator.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
void push_back(Operation *op)
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< BlockOperand > getBlockOperands()
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
Location getLoc()
Return a location for this region.
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context, const DeserializationOptions &options)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.