23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
63LogicalResult spirv::Deserializer::deserialize() {
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
117LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
142#undef MIN_VERSION_CASE
144 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
148 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc,
"OpCapability must have one parameter");
162 auto cap = spirv::symbolizeCapability(operands[0]);
164 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
166 capabilities.insert(*cap);
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
174 "OpExtension must have a literal string for the extension name");
177 unsigned wordIndex = 0;
179 if (wordIndex != words.size())
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
184 return emitError(unknownLoc,
"unknown extension: ") << extName;
186 extensions.insert(*ext);
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
198 unsigned wordIndex = 1;
200 if (wordIndex != words.size()) {
202 "unexpected trailing words in OpExtInstImport");
207void spirv::Deserializer::attachVCETriple() {
209 spirv::ModuleOp::getVCETripleAttrName(),
211 extensions.getArrayRef(), context));
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel
>(operands.front())));
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel
>(operands.back())));
231template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc,
"OpDecoration with ")
238 << decorationName <<
"needs a cache control integer literal and a "
239 << cacheControlKind <<
" cache control literal";
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
243 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
246 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
257 if (words.size() < 2) {
259 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
261 auto decorationName =
262 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (
static_cast<spirv::Decoration
>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc,
"OpDecorate with ")
271 << decorationName <<
" needs a single integer literal";
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode
>(words[2])));
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc,
"OpDecorate with ")
280 << decorationName <<
" needs a single integer literal";
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode
>(words[2])));
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 if (words.size() != 3) {
289 return emitError(unknownLoc,
"OpDecorate with ")
290 << decorationName <<
" needs a single integer literal";
292 decorations[words[0]].set(
293 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
295 case spirv::Decoration::BuiltIn:
296 if (words.size() != 3) {
297 return emitError(unknownLoc,
"OpDecorate with ")
298 << decorationName <<
" needs a single integer literal";
300 decorations[words[0]].set(
301 symbol, opBuilder.getStringAttr(
302 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
304 case spirv::Decoration::ArrayStride:
305 if (words.size() != 3) {
306 return emitError(unknownLoc,
"OpDecorate with ")
307 << decorationName <<
" needs a single integer literal";
309 typeDecorations[words[0]] = words[2];
311 case spirv::Decoration::LinkageAttributes: {
312 if (words.size() < 4) {
313 return emitError(unknownLoc,
"OpDecorate with ")
315 <<
" needs at least 1 string and 1 integer literal";
323 unsigned wordIndex = 2;
325 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
326 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
327 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328 StringAttr::get(context, linkageName), linkageTypeAttr);
329 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
332 case spirv::Decoration::Aliased:
333 case spirv::Decoration::AliasedPointer:
334 case spirv::Decoration::Block:
335 case spirv::Decoration::BufferBlock:
336 case spirv::Decoration::Flat:
337 case spirv::Decoration::NonReadable:
338 case spirv::Decoration::NonWritable:
339 case spirv::Decoration::NoPerspective:
340 case spirv::Decoration::NoSignedWrap:
341 case spirv::Decoration::NoUnsignedWrap:
342 case spirv::Decoration::RelaxedPrecision:
343 case spirv::Decoration::Restrict:
344 case spirv::Decoration::RestrictPointer:
345 case spirv::Decoration::NoContraction:
346 case spirv::Decoration::Constant:
347 case spirv::Decoration::Invariant:
348 case spirv::Decoration::Patch:
349 case spirv::Decoration::Coherent:
350 if (words.size() != 2) {
351 return emitError(unknownLoc,
"OpDecoration with ")
352 << decorationName <<
"needs a single target <id>";
354 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
356 case spirv::Decoration::Location:
357 case spirv::Decoration::SpecId:
358 if (words.size() != 3) {
359 return emitError(unknownLoc,
"OpDecoration with ")
360 << decorationName <<
"needs a single integer literal";
362 decorations[words[0]].set(
363 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
365 case spirv::Decoration::CacheControlLoadINTEL: {
367 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
368 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
374 case spirv::Decoration::CacheControlStoreINTEL: {
376 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
377 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
384 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
390spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
392 if (words.size() < 3) {
394 "OpMemberDecorate must have at least 3 operands");
397 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
398 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
400 " missing offset specification in OpMemberDecorate with "
401 "Offset decoration");
403 ArrayRef<uint32_t> decorationOperands;
404 if (words.size() > 3) {
405 decorationOperands = words.slice(3);
407 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
411LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
412 if (words.size() < 3) {
413 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
415 unsigned wordIndex = 2;
417 if (wordIndex != words.size()) {
419 "unexpected trailing words in OpMemberName instruction");
421 memberNameMap[words[0]][words[1]] = name;
427 if (!decorations.contains(argID)) {
428 argAttrs[argIndex] = DictionaryAttr::get(context, {});
432 spirv::DecorationAttr foundDecorationAttr;
434 for (
auto decoration :
435 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
436 spirv::Decoration::AliasedPointer,
437 spirv::Decoration::RestrictPointer}) {
439 if (decAttr.getName() !=
443 if (foundDecorationAttr)
445 "more than one Aliased/Restrict decorations for "
446 "function argument with result <id> ")
449 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
454 spirv::Decoration::RelaxedPrecision))) {
459 if (foundDecorationAttr)
460 return emitError(unknownLoc,
"already found a decoration for function "
461 "argument with result <id> ")
464 foundDecorationAttr = spirv::DecorationAttr::get(
465 context, spirv::Decoration::RelaxedPrecision);
469 if (!foundDecorationAttr)
470 return emitError(unknownLoc,
"unimplemented decoration support for "
471 "function argument with result <id> ")
474 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
475 foundDecorationAttr);
476 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
483 return emitError(unknownLoc,
"found function inside function");
487 if (operands.size() != 4) {
488 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
492 return emitError(unknownLoc,
"undefined result type from <id> ")
496 uint32_t fnID = operands[1];
497 if (funcMap.count(fnID)) {
498 return emitError(unknownLoc,
"duplicate function definition/declaration");
501 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
503 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
507 if (!fnType || !isa<FunctionType>(fnType)) {
508 return emitError(unknownLoc,
"unknown function type from <id> ")
511 auto functionType = cast<FunctionType>(fnType);
513 if ((
isVoidType(resultType) && functionType.getNumResults() != 0) ||
514 (functionType.getNumResults() == 1 &&
515 functionType.getResult(0) != resultType)) {
516 return emitError(unknownLoc,
"mismatch in function type ")
517 << functionType <<
" and return type " << resultType <<
" specified";
521 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
522 functionType, fnControl.value());
524 if (decorations.count(fnID)) {
525 for (
auto attr : decorations[fnID].getAttrs()) {
526 funcOp->setAttr(attr.getName(), attr.getValue());
529 curFunction = funcMap[fnID] = funcOp;
530 auto *entryBlock = funcOp.addEntryBlock();
533 <<
"//===-------------------------------------------===//\n";
534 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
535 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
536 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
537 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
542 argAttrs.resize(functionType.getNumInputs());
545 if (functionType.getNumInputs()) {
546 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
547 auto argType = functionType.getInput(i);
548 spirv::Opcode opcode = spirv::Opcode::OpNop;
551 spirv::Opcode::OpFunctionParameter))) {
554 if (opcode != spirv::Opcode::OpFunctionParameter) {
557 "missing OpFunctionParameter instruction for argument ")
560 if (operands.size() != 2) {
563 "expected result type and result <id> for OpFunctionParameter");
565 auto argDefinedType =
getType(operands[0]);
566 if (!argDefinedType || argDefinedType != argType) {
568 "mismatch in argument type between function type "
570 << functionType <<
" and argument type definition "
571 << argDefinedType <<
" at argument " << i;
574 return emitError(unknownLoc,
"duplicate definition of result <id> ")
581 auto argValue = funcOp.getArgument(i);
582 valueMap[operands[1]] = argValue;
586 if (llvm::any_of(argAttrs, [](
Attribute attr) {
587 auto argAttr = cast<DictionaryAttr>(attr);
588 return !argAttr.empty();
590 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
595 auto linkageAttr = funcOp.getLinkageAttributes();
596 auto hasImportLinkage =
597 linkageAttr && (linkageAttr.value().getLinkageType().
getValue() ==
598 spirv::LinkageType::Import);
599 if (hasImportLinkage)
606 spirv::Opcode opcode = spirv::Opcode::OpNop;
615 spirv::Opcode::OpFunctionEnd))) {
618 if (opcode == spirv::Opcode::OpFunctionEnd) {
621 if (opcode != spirv::Opcode::OpLabel) {
622 return emitError(unknownLoc,
"a basic block must start with OpLabel");
624 if (instOperands.size() != 1) {
625 return emitError(unknownLoc,
"OpLabel should only have result <id>");
627 blockMap[instOperands[0]] = entryBlock;
635 spirv::Opcode::OpFunctionEnd)) &&
636 opcode != spirv::Opcode::OpFunctionEnd) {
641 if (opcode != spirv::Opcode::OpFunctionEnd) {
651 if (!operands.empty()) {
652 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
663 curFunction = std::nullopt;
668 <<
"//===-------------------------------------------===//\n";
675 if (operands.size() < 2) {
677 "missing graph defintion in OpGraphEntryPointARM");
680 unsigned wordIndex = 0;
681 uint32_t graphID = operands[wordIndex++];
682 if (!graphMap.contains(graphID)) {
684 "missing graph definition/declaration with id ")
688 spirv::GraphARMOp graphARM = graphMap[graphID];
690 graphARM.setSymName(name);
691 graphARM.setEntryPoint(
true);
694 for (
int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
696 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
698 return emitError(unknownLoc,
"undefined result <id> ")
699 << operands[wordIndex] <<
" while decoding OpGraphEntryPoint";
705 opBuilder.setInsertionPoint(graphARM);
706 spirv::GraphEntryPointARMOp::create(
707 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
708 opBuilder.getArrayAttr(interface));
716 return emitError(unknownLoc,
"found graph inside graph");
719 if (operands.size() < 2) {
720 return emitError(unknownLoc,
"OpGraphARM must have at least 2 parameters");
724 if (!type || !isa<GraphType>(type)) {
725 return emitError(unknownLoc,
"unknown graph type from <id> ")
728 auto graphType = cast<GraphType>(type);
729 if (graphType.getNumResults() <= 0) {
730 return emitError(unknownLoc,
"expected at least one result");
733 uint32_t graphID = operands[1];
734 if (graphMap.count(graphID)) {
735 return emitError(unknownLoc,
"duplicate graph definition/declaration");
740 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
741 curGraph = graphMap[graphID] = graphOp;
742 Block *entryBlock = graphOp.addEntryBlock();
745 <<
"//===-------------------------------------------===//\n";
746 logger.startLine() <<
"[graph] name: " << graphName <<
"\n";
747 logger.startLine() <<
"[graph] type: " << graphType <<
"\n";
748 logger.startLine() <<
"[graph] ID: " << graphID <<
"\n";
749 logger.startLine() <<
"[graph] entry block: " << entryBlock <<
"\n";
754 for (
auto [
index, argType] : llvm::enumerate(graphType.getInputs())) {
755 spirv::Opcode opcode;
758 spirv::Opcode::OpGraphInputARM))) {
761 if (operands.size() != 3) {
762 return emitError(unknownLoc,
"expected result type, result <id> and "
763 "input index for OpGraphInputARM");
767 if (!argDefinedType) {
768 return emitError(unknownLoc,
"unknown operand type <id> ") << operands[0];
771 if (argDefinedType != argType) {
773 "mismatch in argument type between graph type "
775 << graphType <<
" and argument type definition " << argDefinedType
776 <<
" at argument " <<
index;
779 return emitError(unknownLoc,
"duplicate definition of result <id> ")
784 if (!inputIndexAttr) {
786 "unable to read inputIndex value from constant op ")
789 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
790 valueMap[operands[1]] = argValue;
793 graphOutputs.resize(graphType.getNumResults());
799 blockMap[graphID] = entryBlock;
806 spirv::Opcode opcode;
816 }
while (opcode != spirv::Opcode::OpGraphEndARM);
823 if (operands.size() != 2) {
826 "expected value id and output index for OpGraphSetOutputARM");
829 uint32_t
id = operands[0];
832 return emitError(unknownLoc,
"could not find result <id> ") << id;
836 if (!outputIndexAttr) {
838 "unable to read outputIndex value from constant op ")
841 graphOutputs[outputIndexAttr.getInt()] = value;
848 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
851 if (!operands.empty()) {
852 return emitError(unknownLoc,
"unexpected operands for OpGraphEndARM");
856 curGraph = std::nullopt;
857 graphOutputs.clear();
862 <<
"//===-------------------------------------------===//\n";
867std::optional<std::pair<Attribute, Type>>
869 auto constIt = constantMap.find(
id);
870 if (constIt == constantMap.end())
872 return constIt->getSecond();
875std::optional<std::pair<Attribute, Type>>
877 if (
auto it = constantCompositeReplicateMap.find(
id);
878 it != constantCompositeReplicateMap.end())
883std::optional<spirv::SpecConstOperationMaterializationInfo>
885 auto constIt = specConstOperationMap.find(
id);
886 if (constIt == specConstOperationMap.end())
888 return constIt->getSecond();
892 auto funcName = nameMap.lookup(
id).str();
893 if (funcName.empty()) {
894 funcName =
"spirv_fn_" + std::to_string(
id);
900 std::string graphName = nameMap.lookup(
id).str();
901 if (graphName.empty()) {
902 graphName =
"spirv_graph_" + std::to_string(
id);
908 auto constName = nameMap.lookup(
id).str();
909 if (constName.empty()) {
910 constName =
"spirv_spec_const_" + std::to_string(
id);
917 TypedAttr defaultValue) {
919 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
921 if (decorations.count(resultID)) {
922 for (
auto attr : decorations[resultID].getAttrs())
923 op->setAttr(attr.getName(), attr.getValue());
925 specConstMap[resultID] = op;
929std::optional<spirv::GraphConstantARMOpMaterializationInfo>
931 auto graphConstIt = graphConstantMap.find(
id);
932 if (graphConstIt == graphConstantMap.end())
934 return graphConstIt->getSecond();
939 unsigned wordIndex = 0;
940 if (operands.size() < 3) {
943 "OpVariable needs at least 3 operands, type, <id> and storage class");
947 auto type =
getType(operands[wordIndex]);
949 return emitError(unknownLoc,
"unknown result type <id> : ")
950 << operands[wordIndex];
952 auto ptrType = dyn_cast<spirv::PointerType>(type);
955 "expected a result type <id> to be a spirv.ptr, found : ")
961 auto variableID = operands[wordIndex];
962 auto variableName = nameMap.lookup(variableID).str();
963 if (variableName.empty()) {
964 variableName =
"spirv_var_" + std::to_string(variableID);
969 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
970 if (ptrType.getStorageClass() != storageClass) {
971 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
972 << type <<
" and that specified in OpVariable instruction : "
973 << stringifyStorageClass(storageClass);
980 if (wordIndex < operands.size()) {
990 return emitError(unknownLoc,
"unknown <id> ")
991 << operands[wordIndex] <<
"used as initializer";
993 initializer = SymbolRefAttr::get(op);
996 if (wordIndex != operands.size()) {
998 "found more operands than expected when deserializing "
999 "OpVariable instruction, only ")
1000 << wordIndex <<
" of " << operands.size() <<
" processed";
1003 auto varOp = spirv::GlobalVariableOp::create(
1004 opBuilder, loc, TypeAttr::get(type),
1005 opBuilder.getStringAttr(variableName), initializer);
1008 if (decorations.count(variableID)) {
1009 for (
auto attr : decorations[variableID].getAttrs())
1010 varOp->setAttr(attr.getName(), attr.getValue());
1012 globalVariableMap[variableID] = varOp;
1021 return dyn_cast<IntegerAttr>(constInfo->first);
1025 if (operands.size() < 2) {
1026 return emitError(unknownLoc,
"OpName needs at least 2 operands");
1028 if (!nameMap.lookup(operands[0]).empty()) {
1029 return emitError(unknownLoc,
"duplicate name found for result <id> ")
1032 unsigned wordIndex = 1;
1034 if (wordIndex != operands.size()) {
1036 "unexpected trailing words in OpName instruction");
1038 nameMap[operands[0]] = name;
1048 if (operands.empty()) {
1049 return emitError(unknownLoc,
"type instruction with opcode ")
1050 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
1055 if (typeMap.count(operands[0])) {
1056 return emitError(unknownLoc,
"duplicate definition for result <id> ")
1061 case spirv::Opcode::OpTypeVoid:
1062 if (operands.size() != 1)
1063 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
1064 typeMap[operands[0]] = opBuilder.getNoneType();
1066 case spirv::Opcode::OpTypeBool:
1067 if (operands.size() != 1)
1068 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
1069 typeMap[operands[0]] = opBuilder.getI1Type();
1071 case spirv::Opcode::OpTypeInt: {
1072 if (operands.size() != 3)
1074 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
1083 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1084 : IntegerType::SignednessSemantics::Signless;
1085 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1087 case spirv::Opcode::OpTypeFloat: {
1088 if (operands.size() != 2 && operands.size() != 3)
1090 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1091 "or 3 operands (type, bitwidth, encoding), but got ")
1093 uint32_t bitWidth = operands[1];
1098 floatTy = opBuilder.getF16Type();
1101 floatTy = opBuilder.getF32Type();
1104 floatTy = opBuilder.getF64Type();
1107 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
1111 if (operands.size() == 3) {
1112 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
1113 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
1117 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
1118 << bitWidth <<
" (expected 16)";
1119 floatTy = opBuilder.getBF16Type();
1122 typeMap[operands[0]] = floatTy;
1124 case spirv::Opcode::OpTypeVector: {
1125 if (operands.size() != 3) {
1128 "OpTypeVector must have element type and count parameters");
1132 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
1135 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1137 case spirv::Opcode::OpTypePointer: {
1140 case spirv::Opcode::OpTypeArray:
1142 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1144 case spirv::Opcode::OpTypeFunction:
1146 case spirv::Opcode::OpTypeImage:
1148 case spirv::Opcode::OpTypeSampledImage:
1150 case spirv::Opcode::OpTypeRuntimeArray:
1152 case spirv::Opcode::OpTypeStruct:
1154 case spirv::Opcode::OpTypeMatrix:
1156 case spirv::Opcode::OpTypeTensorARM:
1158 case spirv::Opcode::OpTypeGraphARM:
1161 return emitError(unknownLoc,
"unhandled type instruction");
1168 if (operands.size() != 3)
1169 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
1171 auto pointeeType =
getType(operands[2]);
1173 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
1176 uint32_t typePointerID = operands[0];
1177 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
1180 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1181 deferredStructIt != std::end(deferredStructTypesInfos);) {
1182 for (
auto *unresolvedMemberIt =
1183 std::begin(deferredStructIt->unresolvedMemberTypes);
1184 unresolvedMemberIt !=
1185 std::end(deferredStructIt->unresolvedMemberTypes);) {
1186 if (unresolvedMemberIt->first == typePointerID) {
1190 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1191 typeMap[typePointerID];
1192 unresolvedMemberIt =
1193 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1195 ++unresolvedMemberIt;
1199 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1201 auto structType = deferredStructIt->deferredStructType;
1203 assert(structType &&
"expected a spirv::StructType");
1204 assert(structType.isIdentified() &&
"expected an indentified struct");
1206 if (failed(structType.trySetBody(
1207 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1208 deferredStructIt->memberDecorationsInfo,
1209 deferredStructIt->structDecorationsInfo)))
1212 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1223 if (operands.size() != 3) {
1225 "OpTypeArray must have element type and count parameters");
1230 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1238 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1239 << operands[2] <<
"can only come from normal constant right now";
1242 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1243 count = intVal.getValue().getZExtValue();
1245 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1246 "scalar integer constant instruction");
1250 elementTy, count, typeDecorations.lookup(operands[0]));
1256 assert(!operands.empty() &&
"No operands for processing function type");
1257 if (operands.size() == 1) {
1258 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1260 auto returnType =
getType(operands[1]);
1262 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1265 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1266 auto ty =
getType(operands[i]);
1268 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1270 argTypes.push_back(ty);
1276 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1282 if (operands.size() != 6) {
1284 "OpTypeCooperativeMatrixKHR must have element type, "
1285 "scope, row and column parameters, and use");
1291 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1295 std::optional<spirv::Scope> scope =
1300 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1309 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1310 "undefined constant <id> ")
1314 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1315 "references undefined constant <id> ")
1319 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1320 "undefined constant <id> ")
1323 unsigned rows = rowsAttr.getInt();
1324 unsigned columns = columnsAttr.getInt();
1326 std::optional<spirv::CooperativeMatrixUseKHR> use =
1327 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1331 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1335 typeMap[operands[0]] =
1342 if (operands.size() != 2) {
1343 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1348 "OpTypeRuntimeArray references undefined <id> ")
1352 memberType, typeDecorations.lookup(operands[0]));
1360 if (operands.empty()) {
1361 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1364 if (operands.size() == 1) {
1366 typeMap[operands[0]] =
1375 for (
auto op : llvm::drop_begin(operands, 1)) {
1377 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1379 if (!memberType && !typeForwardPtr)
1380 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1384 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1386 memberTypes.push_back(memberType);
1391 if (memberDecorationMap.count(operands[0])) {
1392 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1393 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1394 if (allMemberDecorations.count(memberIndex)) {
1395 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1397 if (memberDecoration.first == spirv::Decoration::Offset) {
1399 if (offsetInfo.empty()) {
1400 offsetInfo.resize(memberTypes.size());
1402 offsetInfo[memberIndex] = memberDecoration.second[0];
1404 auto intType = mlir::IntegerType::get(context, 32);
1405 if (!memberDecoration.second.empty()) {
1406 memberDecorationsInfo.emplace_back(
1407 memberIndex, memberDecoration.first,
1408 IntegerAttr::get(intType, memberDecoration.second[0]));
1410 memberDecorationsInfo.emplace_back(
1411 memberIndex, memberDecoration.first, UnitAttr::get(context));
1420 if (decorations.count(operands[0])) {
1423 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1424 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(),
true));
1425 assert(decoration.has_value());
1426 structDecorationsInfo.emplace_back(decoration.value(),
1427 decorationAttr.getValue());
1431 uint32_t structID = operands[0];
1432 std::string structIdentifier = nameMap.lookup(structID).str();
1434 if (structIdentifier.empty()) {
1435 assert(unresolvedMemberTypes.empty() &&
1436 "didn't expect unresolved member types");
1438 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1441 typeMap[structID] = structTy;
1443 if (!unresolvedMemberTypes.empty())
1444 deferredStructTypesInfos.push_back(
1445 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1446 memberDecorationsInfo, structDecorationsInfo});
1447 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1448 memberDecorationsInfo,
1449 structDecorationsInfo)))
1460 if (operands.size() != 3) {
1462 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1463 " (result_id, column_type, and column_count)");
1469 "OpTypeMatrix references undefined column type.")
1473 uint32_t colsCount = operands[2];
1480 unsigned size = operands.size();
1481 if (size < 2 || size > 4)
1482 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1483 "(result_id, element_type, (rank), (shape)) ")
1489 "OpTypeTensorARM references undefined element type ")
1499 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1500 "scalar integer constant instruction");
1501 unsigned rank = rankAttr.getValue().getZExtValue();
1508 std::optional<std::pair<Attribute, Type>> shapeInfo =
1511 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1512 "constant instruction of type OpTypeArray");
1514 ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
1516 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1517 auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
1519 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1521 shape.push_back(dimIntAttr.getValue().getSExtValue());
1529 unsigned size = operands.size();
1531 return emitError(unknownLoc,
"OpTypeGraphARM must have at least 2 operands "
1532 "(result_id, num_inputs, (inout0_type, "
1533 "inout1_type, ...))")
1536 uint32_t numInputs = operands[1];
1539 for (
unsigned i = 2; i < size; ++i) {
1543 "OpTypeGraphARM references undefined element type.")
1546 if (i - 2 >= numInputs) {
1547 returnTypes.push_back(inOutTy);
1549 argTypes.push_back(inOutTy);
1552 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1558 if (operands.size() != 2)
1560 "OpTypeForwardPointer instruction must have two operands");
1562 typeForwardPointerIDs.insert(operands[0]);
1572 if (operands.size() != 8)
1575 "OpTypeImage with non-eight operands are not supported yet");
1579 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1582 auto dim = spirv::symbolizeDim(operands[2]);
1584 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1587 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1589 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1592 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1594 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1597 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1599 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1601 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1602 if (!samplerUseInfo)
1603 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1606 auto format = spirv::symbolizeImageFormat(operands[7]);
1608 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1612 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1613 samplingInfo.value(), samplerUseInfo.value(), format.value());
1619 if (operands.size() != 2)
1620 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1625 "OpTypeSampledImage references undefined <id>: ")
1638 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1640 if (operands.size() < 2) {
1642 << opname <<
" must have type <id> and result <id>";
1644 if (operands.size() < 3) {
1646 << opname <<
" must have at least 1 more parameter";
1651 return emitError(unknownLoc,
"undefined result type from <id> ")
1655 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1656 if (bitwidth == 64) {
1657 if (operands.size() == 4) {
1661 << opname <<
" should have 2 parameters for 64-bit values";
1663 if (bitwidth <= 32) {
1664 if (operands.size() == 3) {
1670 <<
" should have 1 parameter for values with no more than 32 bits";
1672 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1676 auto resultID = operands[1];
1678 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1679 auto bitwidth = intType.getWidth();
1680 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1685 if (bitwidth == 64) {
1692 } words = {operands[2], operands[3]};
1693 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1694 }
else if (bitwidth <= 32) {
1695 value = APInt(bitwidth, operands[2],
true,
1699 auto attr = opBuilder.getIntegerAttr(intType, value);
1706 constantMap.try_emplace(resultID, attr, intType);
1712 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1713 auto bitwidth = floatType.getWidth();
1714 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1719 if (floatType.isF64()) {
1726 } words = {operands[2], operands[3]};
1727 value = APFloat(llvm::bit_cast<double>(words));
1728 }
else if (floatType.isF32()) {
1729 value = APFloat(llvm::bit_cast<float>(operands[2]));
1730 }
else if (floatType.isF16()) {
1731 APInt data(16, operands[2]);
1732 value = APFloat(APFloat::IEEEhalf(), data);
1733 }
else if (floatType.isBF16()) {
1734 APInt data(16, operands[2]);
1735 value = APFloat(APFloat::BFloat(), data);
1738 auto attr = opBuilder.getFloatAttr(floatType, value);
1744 constantMap.try_emplace(resultID, attr, floatType);
1750 return emitError(unknownLoc,
"OpConstant can only generate values of "
1751 "scalar integer or floating-point type");
1756 if (operands.size() != 2) {
1758 << (isSpec ?
"Spec" :
"") <<
"Constant"
1759 << (isTrue ?
"True" :
"False")
1760 <<
" must have type <id> and result <id>";
1763 auto attr = opBuilder.getBoolAttr(isTrue);
1764 auto resultID = operands[1];
1770 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1778 if (operands.size() < 2) {
1780 "OpConstantComposite must have type <id> and result <id>");
1782 if (operands.size() < 3) {
1784 "OpConstantComposite must have at least 1 parameter");
1789 return emitError(unknownLoc,
"undefined result type from <id> ")
1794 elements.reserve(operands.size() - 2);
1795 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1798 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1799 << operands[i] <<
" must come from a normal constant";
1801 elements.push_back(elementInfo->first);
1804 auto resultID = operands[1];
1805 if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1808 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1809 for (
auto value : denseElemAttr.getValues<
Attribute>())
1810 flattenedElems.push_back(value);
1812 flattenedElems.push_back(element);
1816 constantMap.try_emplace(resultID, attr, tensorType);
1817 }
else if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1821 constantMap.try_emplace(resultID, attr, shapedType);
1822 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1823 auto attr = opBuilder.getArrayAttr(elements);
1824 constantMap.try_emplace(resultID, attr, resultType);
1826 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1835 if (operands.size() != 3) {
1838 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1844 return emitError(unknownLoc,
"undefined result type from <id> ")
1848 auto compositeType = dyn_cast<CompositeType>(resultType);
1849 if (!compositeType) {
1851 "result type from <id> is not a composite type")
1855 uint32_t resultID = operands[1];
1856 uint32_t constantID = operands[2];
1858 std::optional<std::pair<Attribute, Type>> constantInfo =
1860 if (constantInfo.has_value()) {
1861 constantCompositeReplicateMap.try_emplace(
1862 resultID, constantInfo.value().first, resultType);
1866 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1868 if (replicatedConstantCompositeInfo.has_value()) {
1869 constantCompositeReplicateMap.try_emplace(
1870 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1874 return emitError(unknownLoc,
"OpConstantCompositeReplicateEXT operand <id> ")
1876 <<
" must come from a normal constant or a "
1877 "OpConstantCompositeReplicateEXT";
1882 if (operands.size() < 2) {
1885 "OpSpecConstantComposite must have type <id> and result <id>");
1887 if (operands.size() < 3) {
1889 "OpSpecConstantComposite must have at least 1 parameter");
1894 return emitError(unknownLoc,
"undefined result type from <id> ")
1898 auto resultID = operands[1];
1902 elements.reserve(operands.size() - 2);
1903 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1905 elements.push_back(SymbolRefAttr::get(elementInfo));
1908 auto op = spirv::SpecConstantCompositeOp::create(
1909 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1910 opBuilder.getArrayAttr(elements));
1911 specConstCompositeMap[resultID] = op;
1918 if (operands.size() != 3) {
1919 return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT expects "
1920 "3 operands but found ")
1926 return emitError(unknownLoc,
"undefined result type from <id> ")
1930 auto compositeType = dyn_cast<CompositeType>(resultType);
1931 if (!compositeType) {
1933 "result type from <id> is not a composite type")
1937 uint32_t resultID = operands[1];
1940 spirv::SpecConstantOp constituentSpecConstantOp =
1942 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1943 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1944 SymbolRefAttr::get(constituentSpecConstantOp));
1946 specConstCompositeReplicateMap[resultID] = op;
1953 if (operands.size() < 3)
1954 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1955 "result <id>, and operand opcode");
1957 uint32_t resultTypeID = operands[0];
1960 return emitError(unknownLoc,
"undefined result type from <id> ")
1963 uint32_t resultID = operands[1];
1964 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1965 auto emplaceResult = specConstOperationMap.try_emplace(
1968 enclosedOpcode, resultTypeID,
1971 if (!emplaceResult.second)
1972 return emitError(unknownLoc,
"value with <id>: ")
1973 << resultID <<
" is probably defined before.";
1979 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1995 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1996 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
1999 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2000 enclosedOpResultTypeAndOperands.push_back(fakeID);
2001 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2002 enclosedOpOperands.end());
2017 auto specConstOperationOp =
2018 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2020 Region &body = specConstOperationOp.getBody();
2022 body.
getBlocks().splice(body.
end(), curBlock->getParent()->getBlocks(),
2029 opBuilder.setInsertionPointToEnd(&block);
2031 spirv::YieldOp::create(opBuilder, loc, block.
front().
getResult(0));
2032 return specConstOperationOp.getResult();
2037 if (operands.size() != 2) {
2039 "OpConstantNull must only have type <id> and result <id>");
2044 return emitError(unknownLoc,
"undefined result type from <id> ")
2048 auto resultID = operands[1];
2050 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
2051 attr = opBuilder.getZeroAttr(resultType);
2052 }
else if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2053 if (
auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2060 constantMap.try_emplace(resultID, attr, resultType);
2064 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
2070 if (operands.size() < 3) {
2072 <<
"OpGraphConstantARM must have at least 2 operands";
2077 return emitError(unknownLoc,
"undefined result type from <id> ")
2081 uint32_t resultID = operands[1];
2083 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2084 return emitError(unknownLoc,
"result must be of type OpTypeTensorARM");
2087 APInt graph_constant_id = APInt(32, operands[2],
true);
2088 Type i32Ty = opBuilder.getIntegerType(32);
2089 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2090 graphConstantMap.try_emplace(
2102 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
2103 <<
" @ " << block <<
"\n");
2110 auto *block = curFunction->addBlock();
2111 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
2112 <<
" @ " << block <<
"\n");
2113 return blockMap[id] = block;
2118 return emitError(unknownLoc,
"OpBranch must appear inside a block");
2121 if (operands.size() != 1) {
2122 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
2130 spirv::BranchOp::create(opBuilder, loc,
target);
2140 "OpBranchConditional must appear inside a block");
2143 if (operands.size() != 3 && operands.size() != 5) {
2145 "OpBranchConditional must have condition, true label, "
2146 "false label, and optionally two branch weights");
2149 auto condition =
getValue(operands[0]);
2153 std::optional<std::pair<uint32_t, uint32_t>> weights;
2154 if (operands.size() == 5) {
2155 weights = std::make_pair(operands[3], operands[4]);
2161 spirv::BranchConditionalOp::create(
2162 opBuilder, loc, condition, trueBlock,
2172 return emitError(unknownLoc,
"OpLabel must appear inside a function");
2175 if (operands.size() != 1) {
2176 return emitError(unknownLoc,
"OpLabel should only have result <id>");
2179 auto labelID = operands[0];
2182 LLVM_DEBUG(logger.startLine()
2183 <<
"[block] populating block " << block <<
"\n");
2185 assert(block->empty() &&
"re-deserialize the same block!");
2187 opBuilder.setInsertionPointToStart(block);
2188 blockMap[labelID] = curBlock = block;
2195 return emitError(unknownLoc,
"a graph block must appear inside a graph");
2200 LLVM_DEBUG(logger.startLine()
2201 <<
"[block] populating block " << block <<
"\n");
2203 assert(block->
empty() &&
"re-deserialize the same block!");
2205 opBuilder.setInsertionPointToStart(block);
2206 blockMap[graphID] = curBlock = block;
2214 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
2217 if (operands.size() < 2) {
2220 "OpSelectionMerge must specify merge target and selection control");
2225 auto selectionControl = operands[1];
2227 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2231 "a block cannot have more than one OpSelectionMerge instruction");
2240 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
2243 if (operands.size() < 3) {
2244 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
2245 "continue target and loop control");
2251 uint32_t loopControl = operands[2];
2254 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2258 "a block cannot have more than one OpLoopMerge instruction");
2266 return emitError(unknownLoc,
"OpPhi must appear in a block");
2269 if (operands.size() < 4) {
2270 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
2271 "and variable-parent pairs");
2276 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2277 valueMap[operands[1]] = blockArg;
2278 LLVM_DEBUG(logger.startLine()
2279 <<
"[phi] created block argument " << blockArg
2280 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
2284 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2285 uint32_t value = operands[i];
2287 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2288 blockPhiInfo[predecessorTargetPair].push_back(value);
2289 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
2290 <<
" with arg id = " << value <<
"\n");
2298 return emitError(unknownLoc,
"OpSwitch must appear in a block");
2300 if (operands.size() < 2)
2301 return emitError(unknownLoc,
"OpSwitch must at least specify selector and "
2302 "a default target");
2304 if (operands.size() % 2)
2306 "OpSwitch must at have an even number of operands: "
2307 "selector, default target and any number of literal and "
2308 "label <id> pairs");
2316 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2317 literals.push_back(operands[i]);
2322 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2331class ControlFlowStructurizer {
2334 ControlFlowStructurizer(
Location loc, uint32_t control,
2337 llvm::ScopedPrinter &logger)
2338 : location(loc), control(control), blockMergeInfo(mergeInfo),
2339 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2342 ControlFlowStructurizer(
Location loc, uint32_t control,
2345 : location(loc), control(control), blockMergeInfo(mergeInfo),
2346 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2356 LogicalResult structurize();
2361 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2364 spirv::LoopOp createLoopOp(uint32_t loopControl);
2367 void collectBlocksInConstruct();
2376 Block *continueBlock;
2382 llvm::ScopedPrinter &logger;
2388ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2391 OpBuilder builder(&mergeBlock->front());
2393 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
2394 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2395 selectionOp.addMergeBlock(builder);
2400spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2403 OpBuilder builder(&mergeBlock->front());
2405 auto control =
static_cast<spirv::LoopControl
>(loopControl);
2406 auto loopOp = spirv::LoopOp::create(builder, location, control);
2407 loopOp.addEntryAndMergeBlock(builder);
2412void ControlFlowStructurizer::collectBlocksInConstruct() {
2413 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
2416 constructBlocks.insert(headerBlock);
2420 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
2421 for (
auto *successor : constructBlocks[i]->getSuccessors())
2422 if (successor != mergeBlock)
2423 constructBlocks.insert(successor);
2427LogicalResult ControlFlowStructurizer::structurize() {
2428 Operation *op =
nullptr;
2429 bool isLoop = continueBlock !=
nullptr;
2431 if (
auto loopOp = createLoopOp(control))
2432 op = loopOp.getOperation();
2434 if (
auto selectionOp = createSelectionOp(control))
2435 op = selectionOp.getOperation();
2444 mapper.
map(mergeBlock, &body.
back());
2446 collectBlocksInConstruct();
2467 OpBuilder builder(body);
2468 for (
auto *block : constructBlocks) {
2471 auto *newBlock = builder.createBlock(&body.
back());
2472 mapper.
map(block, newBlock);
2473 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2474 <<
" from block " << block <<
"\n");
2476 for (BlockArgument blockArg : block->getArguments()) {
2478 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2479 mapper.
map(blockArg, newArg);
2480 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2481 << blockArg <<
" to " << newArg <<
"\n");
2484 LLVM_DEBUG(logger.startLine()
2485 <<
"[cf] block " << block <<
" is a function entry block\n");
2488 for (
auto &op : *block)
2489 newBlock->push_back(op.
clone(mapper));
2493 auto remapOperands = [&](Operation *op) {
2495 if (Value mappedOp = mapper.
lookupOrNull(operand.get()))
2496 operand.set(mappedOp);
2499 succOp.set(mappedOp);
2501 for (
auto &block : body)
2502 block.walk(remapOperands);
2510 headerBlock->replaceAllUsesWith(mergeBlock);
2513 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2514 headerBlock->getParentOp()->print(logger.getOStream());
2515 logger.startLine() <<
"\n";
2519 if (!mergeBlock->args_empty()) {
2520 return mergeBlock->getParentOp()->emitError(
2521 "OpPhi in loop merge block unsupported");
2527 for (BlockArgument blockArg : headerBlock->getArguments())
2528 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2532 SmallVector<Value, 4> blockArgs;
2533 if (!headerBlock->args_empty())
2534 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2538 builder.setInsertionPointToEnd(&body.front());
2539 spirv::BranchOp::create(builder, location, mapper.
lookupOrNull(headerBlock),
2540 ArrayRef<Value>(blockArgs));
2545 SmallVector<Value> valuesToYield;
2548 SmallVector<Value> outsideUses;
2562 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2567 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2568 valuesToYield.push_back(body.back().getArguments().back());
2569 outsideUses.push_back(blockArg);
2574 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2577 for (
auto *block : constructBlocks)
2578 block->dropAllReferences();
2583 for (
Block *block : constructBlocks) {
2584 for (Operation &op : *block) {
2588 outsideUses.push_back(
result);
2591 for (BlockArgument &arg : block->getArguments()) {
2592 if (!arg.use_empty()) {
2594 outsideUses.push_back(arg);
2599 assert(valuesToYield.size() == outsideUses.size());
2603 if (!valuesToYield.empty()) {
2604 LLVM_DEBUG(logger.startLine()
2605 <<
"[cf] yielding values from the selection / loop region\n");
2608 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2609 Operation *merge = llvm::getSingleElement(mergeOps);
2611 merge->setOperands(valuesToYield);
2619 builder.setInsertionPoint(&mergeBlock->front());
2621 Operation *newOp =
nullptr;
2624 newOp = spirv::LoopOp::create(builder, location,
2626 static_cast<spirv::LoopControl
>(control));
2628 newOp = spirv::SelectionOp::create(
2630 static_cast<spirv::SelectionControl
>(control));
2640 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2641 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2647 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2654 for (
auto *block : constructBlocks) {
2655 if (!block->use_empty())
2656 return emitError(block->getParent()->getLoc(),
2657 "failed control flow structurization: "
2658 "block has uses outside of the "
2659 "enclosing selection/loop construct");
2660 for (Operation &op : *block)
2662 return op.
emitOpError(
"failed control flow structurization: value has "
2663 "uses outside of the "
2664 "enclosing selection/loop construct");
2665 for (BlockArgument &arg : block->getArguments())
2666 if (!arg.use_empty())
2667 return emitError(arg.getLoc(),
"failed control flow structurization: "
2668 "block argument has uses outside of the "
2669 "enclosing selection/loop construct");
2673 for (
auto *block : constructBlocks) {
2713 auto updateMergeInfo = [&](
Block *block) -> WalkResult {
2714 auto it = blockMergeInfo.find(block);
2715 if (it != blockMergeInfo.end()) {
2717 Location loc = it->second.loc;
2721 return emitError(loc,
"failed control flow structurization: nested "
2722 "loop header block should be remapped!");
2724 Block *newContinue = it->second.continueBlock;
2728 return emitError(loc,
"failed control flow structurization: nested "
2729 "loop continue block should be remapped!");
2732 Block *newMerge = it->second.mergeBlock;
2734 newMerge = mappedTo;
2738 blockMergeInfo.
erase(it);
2739 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2746 if (block->walk(updateMergeInfo).wasInterrupted())
2754 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2755 <<
" to only contain a spirv.Branch op\n");
2759 builder.setInsertionPointToEnd(block);
2760 spirv::BranchOp::create(builder, location, mergeBlock);
2762 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2767 LLVM_DEBUG(logger.startLine()
2768 <<
"[cf] after structurizing construct with header block "
2769 << headerBlock <<
":\n"
2778 <<
"//----- [phi] start wiring up block arguments -----//\n";
2784 for (
const auto &info : blockPhiInfo) {
2785 Block *block = info.first.first;
2789 logger.startLine() <<
"[phi] block " << block <<
"\n";
2790 logger.startLine() <<
"[phi] before creating block argument:\n";
2792 logger.startLine() <<
"\n";
2798 opBuilder.setInsertionPoint(op);
2801 blockArgs.reserve(phiInfo.size());
2802 for (uint32_t valueId : phiInfo) {
2804 blockArgs.push_back(value);
2805 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2806 <<
" id = " << valueId <<
"\n");
2808 return emitError(unknownLoc,
"OpPhi references undefined value!");
2812 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2814 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2815 branchOp.getTarget(), blockArgs);
2817 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2818 assert((branchCondOp.getTrueBlock() ==
target ||
2819 branchCondOp.getFalseBlock() ==
target) &&
2820 "expected target to be either the true or false target");
2821 if (
target == branchCondOp.getTrueTarget())
2822 spirv::BranchConditionalOp::create(
2823 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2824 blockArgs, branchCondOp.getFalseBlockArguments(),
2825 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2826 branchCondOp.getFalseTarget());
2828 spirv::BranchConditionalOp::create(
2829 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2830 branchCondOp.getTrueBlockArguments(), blockArgs,
2831 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2832 branchCondOp.getFalseBlock());
2834 branchCondOp.erase();
2835 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2836 if (
target == switchOp.getDefaultTarget()) {
2840 spirv::SwitchOp::create(
2841 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2842 switchOp.getDefaultTarget(), blockArgs, literals,
2843 switchOp.getTargets(), targetOperands);
2847 auto it = llvm::find(targets,
target);
2848 assert(it != targets.end());
2849 size_t index = std::distance(targets.begin(), it);
2850 switchOp.getTargetOperandsMutable(
index).assign(blockArgs);
2853 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2857 logger.startLine() <<
"[phi] after creating block argument:\n";
2859 logger.startLine() <<
"\n";
2862 blockPhiInfo.clear();
2867 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2875 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2877 auto &[block, mergeInfo] = *it;
2880 if (mergeInfo.continueBlock)
2883 if (!block->mightHaveTerminator())
2886 Operation *terminator = block->getTerminator();
2889 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2893 bool splitHeaderMergeBlock =
false;
2894 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2895 if (mergeInfo.mergeBlock == block)
2896 splitHeaderMergeBlock =
true;
2903 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2906 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2910 blockMergeInfo.erase(block);
2911 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2919 if (!options.enableControlFlowStructurization) {
2923 <<
"//----- [cf] skip structurizing control flow -----//\n";
2931 <<
"//----- [cf] start structurizing control flow -----//\n";
2936 logger.startLine() <<
"[cf] split conditional blocks\n";
2937 logger.startLine() <<
"\n";
2944 while (!blockMergeInfo.empty()) {
2945 Block *headerBlock = blockMergeInfo.
begin()->first;
2949 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2950 headerBlock->
print(logger.getOStream());
2951 logger.startLine() <<
"\n";
2955 assert(mergeBlock &&
"merge block cannot be nullptr");
2957 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2959 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2960 mergeBlock->print(logger.getOStream());
2961 logger.startLine() <<
"\n";
2965 LLVM_DEBUG(
if (continueBlock) {
2966 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2967 continueBlock->print(logger.getOStream());
2968 logger.startLine() <<
"\n";
2972 blockMergeInfo.
erase(blockMergeInfo.begin());
2973 ControlFlowStructurizer structurizer(mergeInfo.
loc, mergeInfo.
control,
2974 blockMergeInfo, headerBlock,
2975 mergeBlock, continueBlock
2981 if (failed(structurizer.structurize()))
2988 <<
"//--- [cf] completed structurizing control flow ---//\n";
3001 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3002 if (fileName.empty())
3003 fileName =
"<unknown>";
3015 if (operands.size() != 3)
3016 return emitError(unknownLoc,
"OpLine must have 3 operands");
3017 debugLine =
DebugLine{operands[0], operands[1], operands[2]};
3025 if (operands.size() < 2)
3026 return emitError(unknownLoc,
"OpString needs at least 2 operands");
3028 if (!debugInfoMap.lookup(operands[0]).empty())
3030 "duplicate debug string found for result <id> ")
3033 unsigned wordIndex = 1;
3035 if (wordIndex != operands.size())
3037 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< BlockOperand > getBlockOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.