23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/ADT/Sequence.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28 #include "llvm/Support/Debug.h"
29 #include "llvm/Support/SaveAndRestore.h"
30 #include "llvm/Support/raw_ostream.h"
35 #define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
53 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
54 module(createModuleOp()), opBuilder(module->getRegion())
66 <<
"//+++---------- start deserialization ----------+++//\n";
69 if (failed(processHeader()))
72 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 auto binarySize = binary.size();
75 while (curOffset < binarySize) {
78 if (failed(sliceInstruction(opcode, operands)))
81 if (failed(processInstruction(opcode, operands)))
85 assert(curOffset == binarySize &&
86 "deserializer should never index beyond the binary end");
88 for (
auto &deferred : deferredInstructions) {
89 if (failed(processInstruction(deferred.first, deferred.second,
false))) {
96 LLVM_DEBUG(logger.startLine()
97 <<
"//+++-------- completed deserialization --------+++//\n");
102 return std::move(module);
111 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
112 spirv::ModuleOp::build(builder, state);
116 LogicalResult spirv::Deserializer::processHeader() {
119 "SPIR-V binary module must have a 5-word header");
122 return emitError(unknownLoc,
"incorrect magic number");
125 uint32_t majorVersion = (binary[1] << 8) >> 24;
126 uint32_t minorVersion = (binary[1] << 16) >> 24;
127 if (majorVersion == 1) {
128 switch (minorVersion) {
129 #define MIN_VERSION_CASE(v) \
131 version = spirv::Version::V_1_##v; \
140 #undef MIN_VERSION_CASE
142 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
146 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
157 if (operands.size() != 1)
158 return emitError(unknownLoc,
"OpMemoryModel must have one parameter");
160 auto cap = spirv::symbolizeCapability(operands[0]);
162 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
164 capabilities.insert(*cap);
172 "OpExtension must have a literal string for the extension name");
175 unsigned wordIndex = 0;
177 if (wordIndex != words.size())
179 "unexpected trailing words in OpExtension instruction");
180 auto ext = spirv::symbolizeExtension(extName);
182 return emitError(unknownLoc,
"unknown extension: ") << extName;
184 extensions.insert(*ext);
190 if (words.size() < 2) {
192 "OpExtInstImport must have a result <id> and a literal "
193 "string for the extended instruction set name");
196 unsigned wordIndex = 1;
198 if (wordIndex != words.size()) {
200 "unexpected trailing words in OpExtInstImport");
205 void spirv::Deserializer::attachVCETriple() {
207 spirv::ModuleOp::getVCETripleAttrName(),
209 extensions.getArrayRef(), context));
214 if (operands.size() != 2)
215 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
218 module->getAddressingModelAttrName(),
219 opBuilder.getAttr<spirv::AddressingModelAttr>(
220 static_cast<spirv::AddressingModel
>(operands.front())));
222 (*module)->setAttr(module->getMemoryModelAttrName(),
223 opBuilder.getAttr<spirv::MemoryModelAttr>(
224 static_cast<spirv::MemoryModel
>(operands.back())));
229 template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
233 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
234 if (words.size() != 4) {
235 return emitError(loc,
"OpDecoration with ")
236 << decorationName <<
"needs a cache control integer literal and a "
237 << cacheControlKind <<
" cache control literal";
239 unsigned cacheLevel = words[2];
240 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
241 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
244 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
245 llvm::append_range(attrs, attrList);
246 attrs.push_back(value);
247 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
255 if (words.size() < 2) {
257 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
259 auto decorationName =
260 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
261 if (decorationName.empty()) {
262 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
264 auto symbol = getSymbolDecoration(decorationName);
265 switch (
static_cast<spirv::Decoration
>(words[1])) {
266 case spirv::Decoration::FPFastMathMode:
267 if (words.size() != 3) {
268 return emitError(unknownLoc,
"OpDecorate with ")
269 << decorationName <<
" needs a single integer literal";
271 decorations[words[0]].set(
273 static_cast<FPFastMathMode
>(words[2])));
275 case spirv::Decoration::FPRoundingMode:
276 if (words.size() != 3) {
277 return emitError(unknownLoc,
"OpDecorate with ")
278 << decorationName <<
" needs a single integer literal";
280 decorations[words[0]].set(
282 static_cast<FPRoundingMode
>(words[2])));
284 case spirv::Decoration::DescriptorSet:
285 case spirv::Decoration::Binding:
286 if (words.size() != 3) {
287 return emitError(unknownLoc,
"OpDecorate with ")
288 << decorationName <<
" needs a single integer literal";
290 decorations[words[0]].set(
291 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
293 case spirv::Decoration::BuiltIn:
294 if (words.size() != 3) {
295 return emitError(unknownLoc,
"OpDecorate with ")
296 << decorationName <<
" needs a single integer literal";
298 decorations[words[0]].set(
299 symbol, opBuilder.getStringAttr(
300 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
302 case spirv::Decoration::ArrayStride:
303 if (words.size() != 3) {
304 return emitError(unknownLoc,
"OpDecorate with ")
305 << decorationName <<
" needs a single integer literal";
307 typeDecorations[words[0]] = words[2];
309 case spirv::Decoration::LinkageAttributes: {
310 if (words.size() < 4) {
311 return emitError(unknownLoc,
"OpDecorate with ")
313 <<
" needs at least 1 string and 1 integer literal";
321 unsigned wordIndex = 2;
323 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
324 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
325 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
327 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
330 case spirv::Decoration::Aliased:
331 case spirv::Decoration::AliasedPointer:
332 case spirv::Decoration::Block:
333 case spirv::Decoration::BufferBlock:
334 case spirv::Decoration::Flat:
335 case spirv::Decoration::NonReadable:
336 case spirv::Decoration::NonWritable:
337 case spirv::Decoration::NoPerspective:
338 case spirv::Decoration::NoSignedWrap:
339 case spirv::Decoration::NoUnsignedWrap:
340 case spirv::Decoration::RelaxedPrecision:
341 case spirv::Decoration::Restrict:
342 case spirv::Decoration::RestrictPointer:
343 case spirv::Decoration::NoContraction:
344 case spirv::Decoration::Constant:
345 if (words.size() != 2) {
346 return emitError(unknownLoc,
"OpDecoration with ")
347 << decorationName <<
"needs a single target <id>";
353 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
355 case spirv::Decoration::Location:
356 case spirv::Decoration::SpecId:
357 if (words.size() != 3) {
358 return emitError(unknownLoc,
"OpDecoration with ")
359 << decorationName <<
"needs a single integer literal";
361 decorations[words[0]].set(
362 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
364 case spirv::Decoration::CacheControlLoadINTEL: {
366 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
373 case spirv::Decoration::CacheControlStoreINTEL: {
375 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
383 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
391 if (words.size() < 3) {
393 "OpMemberDecorate must have at least 3 operands");
396 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
397 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
399 " missing offset specification in OpMemberDecorate with "
400 "Offset decoration");
403 if (words.size() > 3) {
404 decorationOperands = words.slice(3);
406 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
411 if (words.size() < 3) {
412 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
414 unsigned wordIndex = 2;
416 if (wordIndex != words.size()) {
418 "unexpected trailing words in OpMemberName instruction");
420 memberNameMap[words[0]][words[1]] = name;
424 LogicalResult spirv::Deserializer::setFunctionArgAttrs(
426 if (!decorations.contains(argID)) {
431 spirv::DecorationAttr foundDecorationAttr;
433 for (
auto decoration :
434 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435 spirv::Decoration::AliasedPointer,
436 spirv::Decoration::RestrictPointer}) {
438 if (decAttr.getName() !=
439 getSymbolDecoration(stringifyDecoration(decoration)))
442 if (foundDecorationAttr)
444 "more than one Aliased/Restrict decorations for "
445 "function argument with result <id> ")
453 if (!foundDecorationAttr)
454 return emitError(unknownLoc,
"unimplemented decoration support for "
455 "function argument with result <id> ")
459 foundDecorationAttr);
467 return emitError(unknownLoc,
"found function inside function");
471 if (operands.size() != 4) {
472 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
476 return emitError(unknownLoc,
"undefined result type from <id> ")
480 uint32_t fnID = operands[1];
481 if (funcMap.count(fnID)) {
482 return emitError(unknownLoc,
"duplicate function definition/declaration");
485 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
487 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
491 if (!fnType || !isa<FunctionType>(fnType)) {
492 return emitError(unknownLoc,
"unknown function type from <id> ")
495 auto functionType = cast<FunctionType>(fnType);
497 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
498 (functionType.getNumResults() == 1 &&
499 functionType.getResult(0) != resultType)) {
500 return emitError(unknownLoc,
"mismatch in function type ")
501 << functionType <<
" and return type " << resultType <<
" specified";
504 std::string fnName = getFunctionSymbol(fnID);
505 auto funcOp = opBuilder.create<spirv::FuncOp>(
506 unknownLoc, fnName, functionType, fnControl.value());
508 if (decorations.count(fnID)) {
509 for (
auto attr : decorations[fnID].getAttrs()) {
510 funcOp->setAttr(attr.getName(), attr.getValue());
513 curFunction = funcMap[fnID] = funcOp;
514 auto *entryBlock = funcOp.addEntryBlock();
517 <<
"//===-------------------------------------------===//\n";
518 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
519 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
520 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
521 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
526 argAttrs.resize(functionType.getNumInputs());
529 if (functionType.getNumInputs()) {
530 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
531 auto argType = functionType.getInput(i);
532 spirv::Opcode opcode = spirv::Opcode::OpNop;
534 if (failed(sliceInstruction(opcode, operands,
535 spirv::Opcode::OpFunctionParameter))) {
538 if (opcode != spirv::Opcode::OpFunctionParameter) {
541 "missing OpFunctionParameter instruction for argument ")
544 if (operands.size() != 2) {
547 "expected result type and result <id> for OpFunctionParameter");
549 auto argDefinedType =
getType(operands[0]);
550 if (!argDefinedType || argDefinedType != argType) {
552 "mismatch in argument type between function type "
554 << functionType <<
" and argument type definition "
555 << argDefinedType <<
" at argument " << i;
557 if (getValue(operands[1])) {
558 return emitError(unknownLoc,
"duplicate definition of result <id> ")
561 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
565 auto argValue = funcOp.getArgument(i);
566 valueMap[operands[1]] = argValue;
570 if (llvm::any_of(argAttrs, [](
Attribute attr) {
571 auto argAttr = cast<DictionaryAttr>(attr);
572 return !argAttr.empty();
579 auto linkageAttr = funcOp.getLinkageAttributes();
580 auto hasImportLinkage =
581 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
582 spirv::LinkageType::Import);
583 if (hasImportLinkage)
590 spirv::Opcode opcode = spirv::Opcode::OpNop;
598 if (failed(sliceInstruction(opcode, instOperands,
599 spirv::Opcode::OpFunctionEnd))) {
602 if (opcode == spirv::Opcode::OpFunctionEnd) {
603 return processFunctionEnd(instOperands);
605 if (opcode != spirv::Opcode::OpLabel) {
606 return emitError(unknownLoc,
"a basic block must start with OpLabel");
608 if (instOperands.size() != 1) {
609 return emitError(unknownLoc,
"OpLabel should only have result <id>");
611 blockMap[instOperands[0]] = entryBlock;
612 if (failed(processLabel(instOperands))) {
618 while (succeeded(sliceInstruction(opcode, instOperands,
619 spirv::Opcode::OpFunctionEnd)) &&
620 opcode != spirv::Opcode::OpFunctionEnd) {
621 if (failed(processInstruction(opcode, instOperands))) {
625 if (opcode != spirv::Opcode::OpFunctionEnd) {
629 return processFunctionEnd(instOperands);
635 if (!operands.empty()) {
636 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
642 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
647 curFunction = std::nullopt;
652 <<
"//===-------------------------------------------===//\n";
657 std::optional<std::pair<Attribute, Type>>
658 spirv::Deserializer::getConstant(uint32_t
id) {
659 auto constIt = constantMap.find(
id);
660 if (constIt == constantMap.end())
662 return constIt->getSecond();
665 std::optional<spirv::SpecConstOperationMaterializationInfo>
666 spirv::Deserializer::getSpecConstantOperation(uint32_t
id) {
667 auto constIt = specConstOperationMap.find(
id);
668 if (constIt == specConstOperationMap.end())
670 return constIt->getSecond();
673 std::string spirv::Deserializer::getFunctionSymbol(uint32_t
id) {
674 auto funcName = nameMap.lookup(
id).str();
675 if (funcName.empty()) {
676 funcName =
"spirv_fn_" + std::to_string(
id);
681 std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t
id) {
682 auto constName = nameMap.lookup(
id).str();
683 if (constName.empty()) {
684 constName =
"spirv_spec_const_" + std::to_string(
id);
689 spirv::SpecConstantOp
690 spirv::Deserializer::createSpecConstant(
Location loc, uint32_t resultID,
691 TypedAttr defaultValue) {
692 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
693 auto op = opBuilder.create<spirv::SpecConstantOp>(unknownLoc, symName,
695 if (decorations.count(resultID)) {
696 for (
auto attr : decorations[resultID].getAttrs())
697 op->setAttr(attr.getName(), attr.getValue());
699 specConstMap[resultID] = op;
705 unsigned wordIndex = 0;
706 if (operands.size() < 3) {
709 "OpVariable needs at least 3 operands, type, <id> and storage class");
713 auto type =
getType(operands[wordIndex]);
715 return emitError(unknownLoc,
"unknown result type <id> : ")
716 << operands[wordIndex];
718 auto ptrType = dyn_cast<spirv::PointerType>(type);
721 "expected a result type <id> to be a spirv.ptr, found : ")
727 auto variableID = operands[wordIndex];
728 auto variableName = nameMap.lookup(variableID).str();
729 if (variableName.empty()) {
730 variableName =
"spirv_var_" + std::to_string(variableID);
735 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
736 if (ptrType.getStorageClass() != storageClass) {
737 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
738 << type <<
" and that specified in OpVariable instruction : "
739 << stringifyStorageClass(storageClass);
746 if (wordIndex < operands.size()) {
749 if (
auto initOp = getGlobalVariable(operands[wordIndex]))
751 else if (
auto initOp = getSpecConstant(operands[wordIndex]))
753 else if (
auto initOp = getSpecConstantComposite(operands[wordIndex]))
756 return emitError(unknownLoc,
"unknown <id> ")
757 << operands[wordIndex] <<
"used as initializer";
762 if (wordIndex != operands.size()) {
764 "found more operands than expected when deserializing "
765 "OpVariable instruction, only ")
766 << wordIndex <<
" of " << operands.size() <<
" processed";
768 auto loc = createFileLineColLoc(opBuilder);
769 auto varOp = opBuilder.create<spirv::GlobalVariableOp>(
770 loc,
TypeAttr::get(type), opBuilder.getStringAttr(variableName),
774 if (decorations.count(variableID)) {
775 for (
auto attr : decorations[variableID].getAttrs())
776 varOp->setAttr(attr.getName(), attr.getValue());
778 globalVariableMap[variableID] = varOp;
782 IntegerAttr spirv::Deserializer::getConstantInt(uint32_t
id) {
783 auto constInfo = getConstant(
id);
787 return dyn_cast<IntegerAttr>(constInfo->first);
791 if (operands.size() < 2) {
792 return emitError(unknownLoc,
"OpName needs at least 2 operands");
794 if (!nameMap.lookup(operands[0]).empty()) {
795 return emitError(unknownLoc,
"duplicate name found for result <id> ")
798 unsigned wordIndex = 1;
800 if (wordIndex != operands.size()) {
802 "unexpected trailing words in OpName instruction");
804 nameMap[operands[0]] = name;
812 LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
814 if (operands.empty()) {
815 return emitError(unknownLoc,
"type instruction with opcode ")
816 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
821 if (typeMap.count(operands[0])) {
822 return emitError(unknownLoc,
"duplicate definition for result <id> ")
827 case spirv::Opcode::OpTypeVoid:
828 if (operands.size() != 1)
829 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
830 typeMap[operands[0]] = opBuilder.getNoneType();
832 case spirv::Opcode::OpTypeBool:
833 if (operands.size() != 1)
834 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
835 typeMap[operands[0]] = opBuilder.getI1Type();
837 case spirv::Opcode::OpTypeInt: {
838 if (operands.size() != 3)
840 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
850 : IntegerType::SignednessSemantics::Signless;
853 case spirv::Opcode::OpTypeFloat: {
854 if (operands.size() != 2)
855 return emitError(unknownLoc,
"OpTypeFloat must have bitwidth parameter");
858 switch (operands[1]) {
860 floatTy = opBuilder.getF16Type();
863 floatTy = opBuilder.getF32Type();
866 floatTy = opBuilder.getF64Type();
869 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
872 typeMap[operands[0]] = floatTy;
874 case spirv::Opcode::OpTypeVector: {
875 if (operands.size() != 3) {
878 "OpTypeVector must have element type and count parameters");
882 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
887 case spirv::Opcode::OpTypePointer: {
888 return processOpTypePointer(operands);
890 case spirv::Opcode::OpTypeArray:
891 return processArrayType(operands);
892 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
893 return processCooperativeMatrixTypeKHR(operands);
894 case spirv::Opcode::OpTypeFunction:
895 return processFunctionType(operands);
896 case spirv::Opcode::OpTypeImage:
897 return processImageType(operands);
898 case spirv::Opcode::OpTypeSampledImage:
899 return processSampledImageType(operands);
900 case spirv::Opcode::OpTypeRuntimeArray:
901 return processRuntimeArrayType(operands);
902 case spirv::Opcode::OpTypeStruct:
903 return processStructType(operands);
904 case spirv::Opcode::OpTypeMatrix:
905 return processMatrixType(operands);
907 return emitError(unknownLoc,
"unhandled type instruction");
914 if (operands.size() != 3)
915 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
917 auto pointeeType =
getType(operands[2]);
919 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
922 uint32_t typePointerID = operands[0];
923 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
926 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
927 deferredStructIt != std::end(deferredStructTypesInfos);) {
928 for (
auto *unresolvedMemberIt =
929 std::begin(deferredStructIt->unresolvedMemberTypes);
930 unresolvedMemberIt !=
931 std::end(deferredStructIt->unresolvedMemberTypes);) {
932 if (unresolvedMemberIt->first == typePointerID) {
936 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
937 typeMap[typePointerID];
939 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
941 ++unresolvedMemberIt;
945 if (deferredStructIt->unresolvedMemberTypes.empty()) {
947 auto structType = deferredStructIt->deferredStructType;
949 assert(structType &&
"expected a spirv::StructType");
950 assert(structType.isIdentified() &&
"expected an indentified struct");
952 if (failed(structType.trySetBody(
953 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
954 deferredStructIt->memberDecorationsInfo)))
957 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
968 if (operands.size() != 3) {
970 "OpTypeArray must have element type and count parameters");
975 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
981 auto countInfo = getConstant(operands[2]);
983 return emitError(unknownLoc,
"OpTypeArray count <id> ")
984 << operands[2] <<
"can only come from normal constant right now";
987 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
988 count = intVal.getValue().getZExtValue();
990 return emitError(unknownLoc,
"OpTypeArray count must come from a "
991 "scalar integer constant instruction");
995 elementTy, count, typeDecorations.lookup(operands[0]));
1001 assert(!operands.empty() &&
"No operands for processing function type");
1002 if (operands.size() == 1) {
1003 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1005 auto returnType =
getType(operands[1]);
1007 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1010 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1011 auto ty =
getType(operands[i]);
1013 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1015 argTypes.push_back(ty);
1018 if (!isVoidType(returnType)) {
1025 LogicalResult spirv::Deserializer::processCooperativeMatrixTypeKHR(
1027 if (operands.size() != 6) {
1029 "OpTypeCooperativeMatrixKHR must have element type, "
1030 "scope, row and column parameters, and use");
1036 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1040 std::optional<spirv::Scope> scope =
1041 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1045 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1049 unsigned rows = getConstantInt(operands[3]).getInt();
1050 unsigned columns = getConstantInt(operands[4]).getInt();
1052 std::optional<spirv::CooperativeMatrixUseKHR> use =
1053 spirv::symbolizeCooperativeMatrixUseKHR(
1054 getConstantInt(operands[5]).getInt());
1058 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1062 typeMap[operands[0]] =
1069 if (operands.size() != 2) {
1070 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1075 "OpTypeRuntimeArray references undefined <id> ")
1079 memberType, typeDecorations.lookup(operands[0]));
1087 if (operands.empty()) {
1088 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1091 if (operands.size() == 1) {
1093 typeMap[operands[0]] =
1102 for (
auto op : llvm::drop_begin(operands, 1)) {
1104 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1106 if (!memberType && !typeForwardPtr)
1107 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1111 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1113 memberTypes.push_back(memberType);
1118 if (memberDecorationMap.count(operands[0])) {
1119 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1120 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1121 if (allMemberDecorations.count(memberIndex)) {
1122 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1124 if (memberDecoration.first == spirv::Decoration::Offset) {
1126 if (offsetInfo.empty()) {
1127 offsetInfo.resize(memberTypes.size());
1129 offsetInfo[memberIndex] = memberDecoration.second[0];
1131 if (!memberDecoration.second.empty()) {
1132 memberDecorationsInfo.emplace_back(memberIndex, 1,
1133 memberDecoration.first,
1134 memberDecoration.second[0]);
1136 memberDecorationsInfo.emplace_back(memberIndex, 0,
1137 memberDecoration.first, 0);
1145 uint32_t structID = operands[0];
1146 std::string structIdentifier = nameMap.lookup(structID).str();
1148 if (structIdentifier.empty()) {
1149 assert(unresolvedMemberTypes.empty() &&
1150 "didn't expect unresolved member types");
1155 typeMap[structID] = structTy;
1157 if (!unresolvedMemberTypes.empty())
1158 deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
1159 memberTypes, offsetInfo,
1160 memberDecorationsInfo});
1161 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1162 memberDecorationsInfo)))
1173 if (operands.size() != 3) {
1175 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1176 " (result_id, column_type, and column_count)");
1182 "OpTypeMatrix references undefined column type.")
1186 uint32_t colsCount = operands[2];
1193 if (operands.size() != 2)
1195 "OpTypeForwardPointer instruction must have two operands");
1197 typeForwardPointerIDs.insert(operands[0]);
1207 if (operands.size() != 8)
1210 "OpTypeImage with non-eight operands are not supported yet");
1214 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1217 auto dim = spirv::symbolizeDim(operands[2]);
1219 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1222 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1224 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1227 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1229 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1232 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1234 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1236 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1237 if (!samplerUseInfo)
1238 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1241 auto format = spirv::symbolizeImageFormat(operands[7]);
1243 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1247 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1248 samplingInfo.value(), samplerUseInfo.value(), format.value());
1254 if (operands.size() != 2)
1255 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1260 "OpTypeSampledImage references undefined <id>: ")
1273 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1275 if (operands.size() < 2) {
1277 << opname <<
" must have type <id> and result <id>";
1279 if (operands.size() < 3) {
1281 << opname <<
" must have at least 1 more parameter";
1286 return emitError(unknownLoc,
"undefined result type from <id> ")
1290 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1291 if (bitwidth == 64) {
1292 if (operands.size() == 4) {
1296 << opname <<
" should have 2 parameters for 64-bit values";
1298 if (bitwidth <= 32) {
1299 if (operands.size() == 3) {
1305 <<
" should have 1 parameter for values with no more than 32 bits";
1307 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1311 auto resultID = operands[1];
1313 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1314 auto bitwidth = intType.getWidth();
1315 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1320 if (bitwidth == 64) {
1327 } words = {operands[2], operands[3]};
1328 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1329 }
else if (bitwidth <= 32) {
1330 value = APInt(bitwidth, operands[2],
true,
1334 auto attr = opBuilder.getIntegerAttr(intType, value);
1337 createSpecConstant(unknownLoc, resultID, attr);
1341 constantMap.try_emplace(resultID, attr, intType);
1347 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1348 auto bitwidth = floatType.getWidth();
1349 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1354 if (floatType.isF64()) {
1361 } words = {operands[2], operands[3]};
1362 value = APFloat(llvm::bit_cast<double>(words));
1363 }
else if (floatType.isF32()) {
1364 value = APFloat(llvm::bit_cast<float>(operands[2]));
1365 }
else if (floatType.isF16()) {
1366 APInt data(16, operands[2]);
1367 value = APFloat(APFloat::IEEEhalf(), data);
1370 auto attr = opBuilder.getFloatAttr(floatType, value);
1372 createSpecConstant(unknownLoc, resultID, attr);
1376 constantMap.try_emplace(resultID, attr, floatType);
1382 return emitError(unknownLoc,
"OpConstant can only generate values of "
1383 "scalar integer or floating-point type");
1386 LogicalResult spirv::Deserializer::processConstantBool(
1388 if (operands.size() != 2) {
1390 << (isSpec ?
"Spec" :
"") <<
"Constant"
1391 << (isTrue ?
"True" :
"False")
1392 <<
" must have type <id> and result <id>";
1395 auto attr = opBuilder.getBoolAttr(isTrue);
1396 auto resultID = operands[1];
1398 createSpecConstant(unknownLoc, resultID, attr);
1402 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1410 if (operands.size() < 2) {
1412 "OpConstantComposite must have type <id> and result <id>");
1414 if (operands.size() < 3) {
1416 "OpConstantComposite must have at least 1 parameter");
1421 return emitError(unknownLoc,
"undefined result type from <id> ")
1426 elements.reserve(operands.size() - 2);
1427 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1428 auto elementInfo = getConstant(operands[i]);
1430 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1431 << operands[i] <<
" must come from a normal constant";
1433 elements.push_back(elementInfo->first);
1436 auto resultID = operands[1];
1437 if (
auto vectorType = dyn_cast<VectorType>(resultType)) {
1441 constantMap.try_emplace(resultID, attr, resultType);
1442 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1443 auto attr = opBuilder.getArrayAttr(elements);
1444 constantMap.try_emplace(resultID, attr, resultType);
1446 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1455 if (operands.size() < 2) {
1457 "OpConstantComposite must have type <id> and result <id>");
1459 if (operands.size() < 3) {
1461 "OpConstantComposite must have at least 1 parameter");
1466 return emitError(unknownLoc,
"undefined result type from <id> ")
1470 auto resultID = operands[1];
1471 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1474 elements.reserve(operands.size() - 2);
1475 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1476 auto elementInfo = getSpecConstant(operands[i]);
1480 auto op = opBuilder.
create<spirv::SpecConstantCompositeOp>(
1482 opBuilder.getArrayAttr(elements));
1483 specConstCompositeMap[resultID] = op;
1490 if (operands.size() < 3)
1491 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1492 "result <id>, and operand opcode");
1494 uint32_t resultTypeID = operands[0];
1497 return emitError(unknownLoc,
"undefined result type from <id> ")
1500 uint32_t resultID = operands[1];
1501 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1502 auto emplaceResult = specConstOperationMap.try_emplace(
1504 SpecConstOperationMaterializationInfo{
1505 enclosedOpcode, resultTypeID,
1508 if (!emplaceResult.second)
1509 return emitError(unknownLoc,
"value with <id>: ")
1510 << resultID <<
" is probably defined before.";
1515 Value spirv::Deserializer::materializeSpecConstantOperation(
1516 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1532 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1533 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
1536 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1537 enclosedOpResultTypeAndOperands.push_back(fakeID);
1538 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
1539 enclosedOpOperands.end());
1546 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
1553 auto loc = createFileLineColLoc(opBuilder);
1554 auto specConstOperationOp =
1555 opBuilder.create<spirv::SpecConstantOperationOp>(loc, resultType);
1557 Region &body = specConstOperationOp.getBody();
1559 body.
getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
1566 opBuilder.setInsertionPointToEnd(&block);
1568 opBuilder.create<spirv::YieldOp>(loc, block.
front().
getResult(0));
1569 return specConstOperationOp.getResult();
1574 if (operands.size() != 2) {
1576 "OpConstantNull must have type <id> and result <id>");
1581 return emitError(unknownLoc,
"undefined result type from <id> ")
1585 auto resultID = operands[1];
1586 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
1587 auto attr = opBuilder.getZeroAttr(resultType);
1590 constantMap.try_emplace(resultID, attr, resultType);
1594 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
1602 Block *spirv::Deserializer::getOrCreateBlock(uint32_t
id) {
1603 if (
auto *block = getBlock(
id)) {
1604 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
1605 <<
" @ " << block <<
"\n");
1612 auto *block = curFunction->addBlock();
1613 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
1614 <<
" @ " << block <<
"\n");
1615 return blockMap[id] = block;
1620 return emitError(unknownLoc,
"OpBranch must appear inside a block");
1623 if (operands.size() != 1) {
1624 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
1627 auto *target = getOrCreateBlock(operands[0]);
1628 auto loc = createFileLineColLoc(opBuilder);
1632 opBuilder.create<spirv::BranchOp>(loc, target);
1642 "OpBranchConditional must appear inside a block");
1645 if (operands.size() != 3 && operands.size() != 5) {
1647 "OpBranchConditional must have condition, true label, "
1648 "false label, and optionally two branch weights");
1651 auto condition = getValue(operands[0]);
1652 auto *trueBlock = getOrCreateBlock(operands[1]);
1653 auto *falseBlock = getOrCreateBlock(operands[2]);
1655 std::optional<std::pair<uint32_t, uint32_t>> weights;
1656 if (operands.size() == 5) {
1657 weights = std::make_pair(operands[3], operands[4]);
1662 auto loc = createFileLineColLoc(opBuilder);
1663 opBuilder.create<spirv::BranchConditionalOp>(
1664 loc, condition, trueBlock,
1674 return emitError(unknownLoc,
"OpLabel must appear inside a function");
1677 if (operands.size() != 1) {
1678 return emitError(unknownLoc,
"OpLabel should only have result <id>");
1681 auto labelID = operands[0];
1683 auto *block = getOrCreateBlock(labelID);
1684 LLVM_DEBUG(logger.startLine()
1685 <<
"[block] populating block " << block <<
"\n");
1687 assert(block->
empty() &&
"re-deserialize the same block!");
1689 opBuilder.setInsertionPointToStart(block);
1690 blockMap[labelID] = curBlock = block;
1698 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
1701 if (operands.size() < 2) {
1704 "OpSelectionMerge must specify merge target and selection control");
1707 auto *mergeBlock = getOrCreateBlock(operands[0]);
1708 auto loc = createFileLineColLoc(opBuilder);
1709 auto selectionControl = operands[1];
1711 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
1715 "a block cannot have more than one OpSelectionMerge instruction");
1724 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
1727 if (operands.size() < 3) {
1728 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
1729 "continue target and loop control");
1732 auto *mergeBlock = getOrCreateBlock(operands[0]);
1733 auto *continueBlock = getOrCreateBlock(operands[1]);
1734 auto loc = createFileLineColLoc(opBuilder);
1735 uint32_t loopControl = operands[2];
1738 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
1742 "a block cannot have more than one OpLoopMerge instruction");
1750 return emitError(unknownLoc,
"OpPhi must appear in a block");
1753 if (operands.size() < 4) {
1754 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
1755 "and variable-parent pairs");
1760 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
1761 valueMap[operands[1]] = blockArg;
1762 LLVM_DEBUG(logger.startLine()
1763 <<
"[phi] created block argument " << blockArg
1764 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
1768 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
1769 uint32_t value = operands[i];
1770 Block *predecessor = getOrCreateBlock(operands[i + 1]);
1771 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
1772 blockPhiInfo[predecessorTargetPair].
push_back(value);
1773 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
1774 <<
" with arg id = " << value <<
"\n");
1783 class ControlFlowStructurizer {
1786 ControlFlowStructurizer(
Location loc, uint32_t control,
1789 llvm::ScopedPrinter &logger)
1790 : location(loc), control(control), blockMergeInfo(mergeInfo),
1791 headerBlock(header), mergeBlock(merge), continueBlock(cont),
1794 ControlFlowStructurizer(
Location loc, uint32_t control,
1797 : location(loc), control(control), blockMergeInfo(mergeInfo),
1798 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
1808 LogicalResult structurize();
1813 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
1816 spirv::LoopOp createLoopOp(uint32_t loopControl);
1819 void collectBlocksInConstruct();
1828 Block *continueBlock;
1834 llvm::ScopedPrinter &logger;
1840 ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
1843 OpBuilder builder(&mergeBlock->front());
1845 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
1846 auto selectionOp = builder.create<spirv::SelectionOp>(location, control);
1847 selectionOp.addMergeBlock(builder);
1852 spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
1855 OpBuilder builder(&mergeBlock->front());
1857 auto control =
static_cast<spirv::LoopControl
>(loopControl);
1858 auto loopOp = builder.create<spirv::LoopOp>(location, control);
1859 loopOp.addEntryAndMergeBlock(builder);
1864 void ControlFlowStructurizer::collectBlocksInConstruct() {
1865 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
1868 constructBlocks.insert(headerBlock);
1872 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
1873 for (
auto *successor : constructBlocks[i]->getSuccessors())
1874 if (successor != mergeBlock)
1875 constructBlocks.insert(successor);
1879 LogicalResult ControlFlowStructurizer::structurize() {
1881 bool isLoop = continueBlock !=
nullptr;
1883 if (
auto loopOp = createLoopOp(control))
1884 op = loopOp.getOperation();
1886 if (
auto selectionOp = createSelectionOp(control))
1887 op = selectionOp.getOperation();
1896 mapper.
map(mergeBlock, &body.
back());
1898 collectBlocksInConstruct();
1920 for (
auto *block : constructBlocks) {
1923 auto *newBlock = builder.createBlock(&body.
back());
1924 mapper.
map(block, newBlock);
1925 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
1926 <<
" from block " << block <<
"\n");
1930 newBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
1931 mapper.
map(blockArg, newArg);
1932 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
1933 << blockArg <<
" to " << newArg <<
"\n");
1936 LLVM_DEBUG(logger.startLine()
1937 <<
"[cf] block " << block <<
" is a function entry block\n");
1940 for (
auto &op : *block)
1941 newBlock->push_back(op.
clone(mapper));
1945 auto remapOperands = [&](
Operation *op) {
1948 operand.set(mappedOp);
1951 succOp.set(mappedOp);
1953 for (
auto &block : body)
1954 block.walk(remapOperands);
1962 headerBlock->replaceAllUsesWith(mergeBlock);
1965 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
1966 headerBlock->getParentOp()->
print(logger.getOStream());
1967 logger.startLine() <<
"\n";
1971 if (!mergeBlock->args_empty()) {
1972 return mergeBlock->getParentOp()->emitError(
1973 "OpPhi in loop merge block unsupported");
1980 mergeBlock->addArgument(blockArg.
getType(), blockArg.
getLoc());
1985 if (!headerBlock->args_empty())
1986 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
1990 builder.setInsertionPointToEnd(&body.front());
1991 builder.create<spirv::BranchOp>(location, mapper.
lookupOrNull(headerBlock),
1997 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2000 for (
auto *block : constructBlocks)
2001 block->dropAllReferences();
2008 for (
auto *block : constructBlocks) {
2012 "failed control flow structurization: it has uses outside of the "
2013 "enclosing selection/loop construct");
2017 for (
auto *block : constructBlocks) {
2026 auto it = blockMergeInfo.find(block);
2027 if (it != blockMergeInfo.end()) {
2033 return emitError(loc,
"failed control flow structurization: nested "
2034 "loop header block should be remapped!");
2036 Block *newContinue = it->second.continueBlock;
2040 return emitError(loc,
"failed control flow structurization: nested "
2041 "loop continue block should be remapped!");
2044 Block *newMerge = it->second.mergeBlock;
2046 newMerge = mappedTo;
2050 blockMergeInfo.
erase(it);
2051 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2060 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2061 <<
" to only contain a spirv.Branch op\n");
2065 builder.setInsertionPointToEnd(block);
2066 builder.create<spirv::BranchOp>(location, mergeBlock);
2068 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2073 LLVM_DEBUG(logger.startLine()
2074 <<
"[cf] after structurizing construct with header block "
2075 << headerBlock <<
":\n"
2081 LogicalResult spirv::Deserializer::wireUpBlockArgument() {
2084 <<
"//----- [phi] start wiring up block arguments -----//\n";
2090 for (
const auto &info : blockPhiInfo) {
2091 Block *block = info.first.first;
2092 Block *target = info.first.second;
2093 const BlockPhiInfo &phiInfo = info.second;
2095 logger.startLine() <<
"[phi] block " << block <<
"\n";
2096 logger.startLine() <<
"[phi] before creating block argument:\n";
2098 logger.startLine() <<
"\n";
2104 opBuilder.setInsertionPoint(op);
2107 blockArgs.reserve(phiInfo.size());
2108 for (uint32_t valueId : phiInfo) {
2109 if (
Value value = getValue(valueId)) {
2110 blockArgs.push_back(value);
2111 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2112 <<
" id = " << valueId <<
"\n");
2114 return emitError(unknownLoc,
"OpPhi references undefined value!");
2118 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2120 opBuilder.create<spirv::BranchOp>(branchOp.getLoc(), branchOp.getTarget(),
2123 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2124 assert((branchCondOp.getTrueBlock() == target ||
2125 branchCondOp.getFalseBlock() == target) &&
2126 "expected target to be either the true or false target");
2127 if (target == branchCondOp.getTrueTarget())
2128 opBuilder.create<spirv::BranchConditionalOp>(
2129 branchCondOp.getLoc(), branchCondOp.getCondition(), blockArgs,
2130 branchCondOp.getFalseBlockArguments(),
2131 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2132 branchCondOp.getFalseTarget());
2134 opBuilder.create<spirv::BranchConditionalOp>(
2135 branchCondOp.getLoc(), branchCondOp.getCondition(),
2136 branchCondOp.getTrueBlockArguments(), blockArgs,
2137 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2138 branchCondOp.getFalseBlock());
2140 branchCondOp.erase();
2142 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2146 logger.startLine() <<
"[phi] after creating block argument:\n";
2148 logger.startLine() <<
"\n";
2151 blockPhiInfo.clear();
2156 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2161 LogicalResult spirv::Deserializer::structurizeControlFlow() {
2164 <<
"//----- [cf] start structurizing control flow -----//\n";
2168 while (!blockMergeInfo.empty()) {
2169 Block *headerBlock = blockMergeInfo.
begin()->first;
2170 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2173 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2174 headerBlock->
print(logger.getOStream());
2175 logger.startLine() <<
"\n";
2178 auto *mergeBlock = mergeInfo.mergeBlock;
2179 assert(mergeBlock &&
"merge block cannot be nullptr");
2180 if (!mergeBlock->args_empty())
2181 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2183 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2184 mergeBlock->print(logger.getOStream());
2185 logger.startLine() <<
"\n";
2188 auto *continueBlock = mergeInfo.continueBlock;
2189 LLVM_DEBUG(
if (continueBlock) {
2190 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2191 continueBlock->print(logger.getOStream());
2192 logger.startLine() <<
"\n";
2196 blockMergeInfo.erase(blockMergeInfo.begin());
2197 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2198 blockMergeInfo, headerBlock,
2199 mergeBlock, continueBlock
2205 if (failed(structurizer.structurize()))
2212 <<
"//--- [cf] completed structurizing control flow ---//\n";
2225 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2226 if (fileName.empty())
2227 fileName =
"<unknown>";
2239 if (operands.size() != 3)
2240 return emitError(unknownLoc,
"OpLine must have 3 operands");
2241 debugLine = DebugLine{operands[0], operands[1], operands[2]};
2245 void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
2249 if (operands.size() < 2)
2250 return emitError(unknownLoc,
"OpString needs at least 2 operands");
2252 if (!debugInfoMap.lookup(operands[0]).empty())
2254 "duplicate debug string found for result <id> ")
2257 unsigned wordIndex = 1;
2259 if (wordIndex != operands.size())
2261 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Location getLoc() const
Return the location for this argument.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
void push_back(Operation *op)
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
StringAttr getStringAttr(const Twine &bytes)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
This is a utility class for mapping one set of IR entities to another.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< BlockOperand > getBlockOperands()
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult deserialize()
Deserializes the remembered SPIR-V binary module.
Deserializer(ArrayRef< uint32_t > binary, MLIRContext *context)
Creates a deserializer for the given SPIR-V binary module.
OwningOpRef< spirv::ModuleOp > collect()
Collects the final SPIR-V ModuleOp.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={})
Construct a literal StructType with at least one member.
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This represents an operation in an abstracted form, suitable for use with the builder APIs.