23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
63LogicalResult spirv::Deserializer::deserialize() {
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
117LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
142#undef MIN_VERSION_CASE
144 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
148 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc,
"OpCapability must have one parameter");
162 auto cap = spirv::symbolizeCapability(operands[0]);
164 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
166 capabilities.insert(*cap);
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
174 "OpExtension must have a literal string for the extension name");
177 unsigned wordIndex = 0;
179 if (wordIndex != words.size())
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
184 return emitError(unknownLoc,
"unknown extension: ") << extName;
186 extensions.insert(*ext);
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
198 unsigned wordIndex = 1;
200 if (wordIndex != words.size()) {
202 "unexpected trailing words in OpExtInstImport");
207void spirv::Deserializer::attachVCETriple() {
209 spirv::ModuleOp::getVCETripleAttrName(),
211 extensions.getArrayRef(), context));
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel
>(operands.front())));
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel
>(operands.back())));
231template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc,
"OpDecoration with ")
238 << decorationName <<
"needs a cache control integer literal and a "
239 << cacheControlKind <<
" cache control literal";
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
243 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
246 llvm::dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
257 if (words.size() < 2) {
259 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
261 auto decorationName =
262 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (
static_cast<spirv::Decoration
>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc,
"OpDecorate with ")
271 << decorationName <<
" needs a single integer literal";
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode
>(words[2])));
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc,
"OpDecorate with ")
280 << decorationName <<
" needs a single integer literal";
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode
>(words[2])));
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 if (words.size() != 3) {
289 return emitError(unknownLoc,
"OpDecorate with ")
290 << decorationName <<
" needs a single integer literal";
292 decorations[words[0]].set(
293 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
295 case spirv::Decoration::BuiltIn:
296 if (words.size() != 3) {
297 return emitError(unknownLoc,
"OpDecorate with ")
298 << decorationName <<
" needs a single integer literal";
300 decorations[words[0]].set(
301 symbol, opBuilder.getStringAttr(
302 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
304 case spirv::Decoration::ArrayStride:
305 if (words.size() != 3) {
306 return emitError(unknownLoc,
"OpDecorate with ")
307 << decorationName <<
" needs a single integer literal";
309 typeDecorations[words[0]] = words[2];
311 case spirv::Decoration::LinkageAttributes: {
312 if (words.size() < 4) {
313 return emitError(unknownLoc,
"OpDecorate with ")
315 <<
" needs at least 1 string and 1 integer literal";
323 unsigned wordIndex = 2;
325 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
326 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
327 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328 StringAttr::get(context, linkageName), linkageTypeAttr);
329 decorations[words[0]].set(symbol, llvm::dyn_cast<Attribute>(linkageAttr));
332 case spirv::Decoration::Aliased:
333 case spirv::Decoration::AliasedPointer:
334 case spirv::Decoration::Block:
335 case spirv::Decoration::BufferBlock:
336 case spirv::Decoration::Flat:
337 case spirv::Decoration::NonReadable:
338 case spirv::Decoration::NonWritable:
339 case spirv::Decoration::NoPerspective:
340 case spirv::Decoration::NoSignedWrap:
341 case spirv::Decoration::NoUnsignedWrap:
342 case spirv::Decoration::RelaxedPrecision:
343 case spirv::Decoration::Restrict:
344 case spirv::Decoration::RestrictPointer:
345 case spirv::Decoration::NoContraction:
346 case spirv::Decoration::Constant:
347 case spirv::Decoration::Invariant:
348 case spirv::Decoration::Patch:
349 if (words.size() != 2) {
350 return emitError(unknownLoc,
"OpDecoration with ")
351 << decorationName <<
"needs a single target <id>";
353 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
355 case spirv::Decoration::Location:
356 case spirv::Decoration::SpecId:
357 if (words.size() != 3) {
358 return emitError(unknownLoc,
"OpDecoration with ")
359 << decorationName <<
"needs a single integer literal";
361 decorations[words[0]].set(
362 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
364 case spirv::Decoration::CacheControlLoadINTEL: {
366 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
367 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
373 case spirv::Decoration::CacheControlStoreINTEL: {
375 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
376 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
383 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
389spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
391 if (words.size() < 3) {
393 "OpMemberDecorate must have at least 3 operands");
396 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
397 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
399 " missing offset specification in OpMemberDecorate with "
400 "Offset decoration");
402 ArrayRef<uint32_t> decorationOperands;
403 if (words.size() > 3) {
404 decorationOperands = words.slice(3);
406 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
410LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
411 if (words.size() < 3) {
412 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
414 unsigned wordIndex = 2;
416 if (wordIndex != words.size()) {
418 "unexpected trailing words in OpMemberName instruction");
420 memberNameMap[words[0]][words[1]] = name;
426 if (!decorations.contains(argID)) {
427 argAttrs[argIndex] = DictionaryAttr::get(context, {});
431 spirv::DecorationAttr foundDecorationAttr;
433 for (
auto decoration :
434 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
435 spirv::Decoration::AliasedPointer,
436 spirv::Decoration::RestrictPointer}) {
438 if (decAttr.getName() !=
442 if (foundDecorationAttr)
444 "more than one Aliased/Restrict decorations for "
445 "function argument with result <id> ")
448 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
453 spirv::Decoration::RelaxedPrecision))) {
458 if (foundDecorationAttr)
459 return emitError(unknownLoc,
"already found a decoration for function "
460 "argument with result <id> ")
463 foundDecorationAttr = spirv::DecorationAttr::get(
464 context, spirv::Decoration::RelaxedPrecision);
468 if (!foundDecorationAttr)
469 return emitError(unknownLoc,
"unimplemented decoration support for "
470 "function argument with result <id> ")
473 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
474 foundDecorationAttr);
475 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
482 return emitError(unknownLoc,
"found function inside function");
486 if (operands.size() != 4) {
487 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
491 return emitError(unknownLoc,
"undefined result type from <id> ")
495 uint32_t fnID = operands[1];
496 if (funcMap.count(fnID)) {
497 return emitError(unknownLoc,
"duplicate function definition/declaration");
500 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
502 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
506 if (!fnType || !isa<FunctionType>(fnType)) {
507 return emitError(unknownLoc,
"unknown function type from <id> ")
510 auto functionType = cast<FunctionType>(fnType);
512 if ((
isVoidType(resultType) && functionType.getNumResults() != 0) ||
513 (functionType.getNumResults() == 1 &&
514 functionType.getResult(0) != resultType)) {
515 return emitError(unknownLoc,
"mismatch in function type ")
516 << functionType <<
" and return type " << resultType <<
" specified";
520 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
521 functionType, fnControl.value());
523 if (decorations.count(fnID)) {
524 for (
auto attr : decorations[fnID].getAttrs()) {
525 funcOp->setAttr(attr.getName(), attr.getValue());
528 curFunction = funcMap[fnID] = funcOp;
529 auto *entryBlock = funcOp.addEntryBlock();
532 <<
"//===-------------------------------------------===//\n";
533 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
534 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
535 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
536 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
541 argAttrs.resize(functionType.getNumInputs());
544 if (functionType.getNumInputs()) {
545 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
546 auto argType = functionType.getInput(i);
547 spirv::Opcode opcode = spirv::Opcode::OpNop;
550 spirv::Opcode::OpFunctionParameter))) {
553 if (opcode != spirv::Opcode::OpFunctionParameter) {
556 "missing OpFunctionParameter instruction for argument ")
559 if (operands.size() != 2) {
562 "expected result type and result <id> for OpFunctionParameter");
564 auto argDefinedType =
getType(operands[0]);
565 if (!argDefinedType || argDefinedType != argType) {
567 "mismatch in argument type between function type "
569 << functionType <<
" and argument type definition "
570 << argDefinedType <<
" at argument " << i;
573 return emitError(unknownLoc,
"duplicate definition of result <id> ")
580 auto argValue = funcOp.getArgument(i);
581 valueMap[operands[1]] = argValue;
585 if (llvm::any_of(argAttrs, [](
Attribute attr) {
586 auto argAttr = cast<DictionaryAttr>(attr);
587 return !argAttr.empty();
589 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
594 auto linkageAttr = funcOp.getLinkageAttributes();
595 auto hasImportLinkage =
596 linkageAttr && (linkageAttr.value().getLinkageType().
getValue() ==
597 spirv::LinkageType::Import);
598 if (hasImportLinkage)
605 spirv::Opcode opcode = spirv::Opcode::OpNop;
614 spirv::Opcode::OpFunctionEnd))) {
617 if (opcode == spirv::Opcode::OpFunctionEnd) {
620 if (opcode != spirv::Opcode::OpLabel) {
621 return emitError(unknownLoc,
"a basic block must start with OpLabel");
623 if (instOperands.size() != 1) {
624 return emitError(unknownLoc,
"OpLabel should only have result <id>");
626 blockMap[instOperands[0]] = entryBlock;
634 spirv::Opcode::OpFunctionEnd)) &&
635 opcode != spirv::Opcode::OpFunctionEnd) {
640 if (opcode != spirv::Opcode::OpFunctionEnd) {
650 if (!operands.empty()) {
651 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
662 curFunction = std::nullopt;
667 <<
"//===-------------------------------------------===//\n";
674 if (operands.size() < 2) {
676 "missing graph defintion in OpGraphEntryPointARM");
679 unsigned wordIndex = 0;
680 uint32_t graphID = operands[wordIndex++];
681 if (!graphMap.contains(graphID)) {
683 "missing graph definition/declaration with id ")
687 spirv::GraphARMOp graphARM = graphMap[graphID];
689 graphARM.setSymName(name);
690 graphARM.setEntryPoint(
true);
693 for (
int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
695 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
697 return emitError(unknownLoc,
"undefined result <id> ")
698 << operands[wordIndex] <<
" while decoding OpGraphEntryPoint";
704 opBuilder.setInsertionPoint(graphARM);
705 spirv::GraphEntryPointARMOp::create(
706 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
707 opBuilder.getArrayAttr(interface));
715 return emitError(unknownLoc,
"found graph inside graph");
718 if (operands.size() < 2) {
719 return emitError(unknownLoc,
"OpGraphARM must have at least 2 parameters");
723 if (!type || !isa<GraphType>(type)) {
724 return emitError(unknownLoc,
"unknown graph type from <id> ")
727 auto graphType = cast<GraphType>(type);
728 if (graphType.getNumResults() <= 0) {
729 return emitError(unknownLoc,
"expected at least one result");
732 uint32_t graphID = operands[1];
733 if (graphMap.count(graphID)) {
734 return emitError(unknownLoc,
"duplicate graph definition/declaration");
739 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
740 curGraph = graphMap[graphID] = graphOp;
741 Block *entryBlock = graphOp.addEntryBlock();
744 <<
"//===-------------------------------------------===//\n";
745 logger.startLine() <<
"[graph] name: " << graphName <<
"\n";
746 logger.startLine() <<
"[graph] type: " << graphType <<
"\n";
747 logger.startLine() <<
"[graph] ID: " << graphID <<
"\n";
748 logger.startLine() <<
"[graph] entry block: " << entryBlock <<
"\n";
753 for (
auto [
index, argType] : llvm::enumerate(graphType.getInputs())) {
754 spirv::Opcode opcode;
757 spirv::Opcode::OpGraphInputARM))) {
760 if (operands.size() != 3) {
761 return emitError(unknownLoc,
"expected result type, result <id> and "
762 "input index for OpGraphInputARM");
766 if (!argDefinedType) {
767 return emitError(unknownLoc,
"unknown operand type <id> ") << operands[0];
770 if (argDefinedType != argType) {
772 "mismatch in argument type between graph type "
774 << graphType <<
" and argument type definition " << argDefinedType
775 <<
" at argument " <<
index;
778 return emitError(unknownLoc,
"duplicate definition of result <id> ")
783 if (!inputIndexAttr) {
785 "unable to read inputIndex value from constant op ")
788 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
789 valueMap[operands[1]] = argValue;
792 graphOutputs.resize(graphType.getNumResults());
798 blockMap[graphID] = entryBlock;
805 spirv::Opcode opcode;
815 }
while (opcode != spirv::Opcode::OpGraphEndARM);
822 if (operands.size() != 2) {
825 "expected value id and output index for OpGraphSetOutputARM");
828 uint32_t
id = operands[0];
831 return emitError(unknownLoc,
"could not find result <id> ") << id;
835 if (!outputIndexAttr) {
837 "unable to read outputIndex value from constant op ")
840 graphOutputs[outputIndexAttr.getInt()] = value;
847 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
850 if (!operands.empty()) {
851 return emitError(unknownLoc,
"unexpected operands for OpGraphEndARM");
855 curGraph = std::nullopt;
856 graphOutputs.clear();
861 <<
"//===-------------------------------------------===//\n";
866std::optional<std::pair<Attribute, Type>>
868 auto constIt = constantMap.find(
id);
869 if (constIt == constantMap.end())
871 return constIt->getSecond();
874std::optional<std::pair<Attribute, Type>>
876 if (
auto it = constantCompositeReplicateMap.find(
id);
877 it != constantCompositeReplicateMap.end())
882std::optional<spirv::SpecConstOperationMaterializationInfo>
884 auto constIt = specConstOperationMap.find(
id);
885 if (constIt == specConstOperationMap.end())
887 return constIt->getSecond();
891 auto funcName = nameMap.lookup(
id).str();
892 if (funcName.empty()) {
893 funcName =
"spirv_fn_" + std::to_string(
id);
899 std::string graphName = nameMap.lookup(
id).str();
900 if (graphName.empty()) {
901 graphName =
"spirv_graph_" + std::to_string(
id);
907 auto constName = nameMap.lookup(
id).str();
908 if (constName.empty()) {
909 constName =
"spirv_spec_const_" + std::to_string(
id);
916 TypedAttr defaultValue) {
918 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
920 if (decorations.count(resultID)) {
921 for (
auto attr : decorations[resultID].getAttrs())
922 op->setAttr(attr.getName(), attr.getValue());
924 specConstMap[resultID] = op;
928std::optional<spirv::GraphConstantARMOpMaterializationInfo>
930 auto graphConstIt = graphConstantMap.find(
id);
931 if (graphConstIt == graphConstantMap.end())
933 return graphConstIt->getSecond();
938 unsigned wordIndex = 0;
939 if (operands.size() < 3) {
942 "OpVariable needs at least 3 operands, type, <id> and storage class");
946 auto type =
getType(operands[wordIndex]);
948 return emitError(unknownLoc,
"unknown result type <id> : ")
949 << operands[wordIndex];
951 auto ptrType = dyn_cast<spirv::PointerType>(type);
954 "expected a result type <id> to be a spirv.ptr, found : ")
960 auto variableID = operands[wordIndex];
961 auto variableName = nameMap.lookup(variableID).str();
962 if (variableName.empty()) {
963 variableName =
"spirv_var_" + std::to_string(variableID);
968 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
969 if (ptrType.getStorageClass() != storageClass) {
970 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
971 << type <<
" and that specified in OpVariable instruction : "
972 << stringifyStorageClass(storageClass);
979 if (wordIndex < operands.size()) {
989 return emitError(unknownLoc,
"unknown <id> ")
990 << operands[wordIndex] <<
"used as initializer";
992 initializer = SymbolRefAttr::get(op);
995 if (wordIndex != operands.size()) {
997 "found more operands than expected when deserializing "
998 "OpVariable instruction, only ")
999 << wordIndex <<
" of " << operands.size() <<
" processed";
1002 auto varOp = spirv::GlobalVariableOp::create(
1003 opBuilder, loc, TypeAttr::get(type),
1004 opBuilder.getStringAttr(variableName), initializer);
1007 if (decorations.count(variableID)) {
1008 for (
auto attr : decorations[variableID].getAttrs())
1009 varOp->setAttr(attr.getName(), attr.getValue());
1011 globalVariableMap[variableID] = varOp;
1020 return dyn_cast<IntegerAttr>(constInfo->first);
1024 if (operands.size() < 2) {
1025 return emitError(unknownLoc,
"OpName needs at least 2 operands");
1027 if (!nameMap.lookup(operands[0]).empty()) {
1028 return emitError(unknownLoc,
"duplicate name found for result <id> ")
1031 unsigned wordIndex = 1;
1033 if (wordIndex != operands.size()) {
1035 "unexpected trailing words in OpName instruction");
1037 nameMap[operands[0]] = name;
1047 if (operands.empty()) {
1048 return emitError(unknownLoc,
"type instruction with opcode ")
1049 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
1054 if (typeMap.count(operands[0])) {
1055 return emitError(unknownLoc,
"duplicate definition for result <id> ")
1060 case spirv::Opcode::OpTypeVoid:
1061 if (operands.size() != 1)
1062 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
1063 typeMap[operands[0]] = opBuilder.getNoneType();
1065 case spirv::Opcode::OpTypeBool:
1066 if (operands.size() != 1)
1067 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
1068 typeMap[operands[0]] = opBuilder.getI1Type();
1070 case spirv::Opcode::OpTypeInt: {
1071 if (operands.size() != 3)
1073 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
1082 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1083 : IntegerType::SignednessSemantics::Signless;
1084 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1086 case spirv::Opcode::OpTypeFloat: {
1087 if (operands.size() != 2 && operands.size() != 3)
1089 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1090 "or 3 operands (type, bitwidth, encoding), but got ")
1092 uint32_t bitWidth = operands[1];
1097 floatTy = opBuilder.getF16Type();
1100 floatTy = opBuilder.getF32Type();
1103 floatTy = opBuilder.getF64Type();
1106 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
1110 if (operands.size() == 3) {
1111 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
1112 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
1116 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
1117 << bitWidth <<
" (expected 16)";
1118 floatTy = opBuilder.getBF16Type();
1121 typeMap[operands[0]] = floatTy;
1123 case spirv::Opcode::OpTypeVector: {
1124 if (operands.size() != 3) {
1127 "OpTypeVector must have element type and count parameters");
1131 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
1134 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1136 case spirv::Opcode::OpTypePointer: {
1139 case spirv::Opcode::OpTypeArray:
1141 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1143 case spirv::Opcode::OpTypeFunction:
1145 case spirv::Opcode::OpTypeImage:
1147 case spirv::Opcode::OpTypeSampledImage:
1149 case spirv::Opcode::OpTypeRuntimeArray:
1151 case spirv::Opcode::OpTypeStruct:
1153 case spirv::Opcode::OpTypeMatrix:
1155 case spirv::Opcode::OpTypeTensorARM:
1157 case spirv::Opcode::OpTypeGraphARM:
1160 return emitError(unknownLoc,
"unhandled type instruction");
1167 if (operands.size() != 3)
1168 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
1170 auto pointeeType =
getType(operands[2]);
1172 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
1175 uint32_t typePointerID = operands[0];
1176 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
1179 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1180 deferredStructIt != std::end(deferredStructTypesInfos);) {
1181 for (
auto *unresolvedMemberIt =
1182 std::begin(deferredStructIt->unresolvedMemberTypes);
1183 unresolvedMemberIt !=
1184 std::end(deferredStructIt->unresolvedMemberTypes);) {
1185 if (unresolvedMemberIt->first == typePointerID) {
1189 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1190 typeMap[typePointerID];
1191 unresolvedMemberIt =
1192 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1194 ++unresolvedMemberIt;
1198 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1200 auto structType = deferredStructIt->deferredStructType;
1202 assert(structType &&
"expected a spirv::StructType");
1203 assert(structType.isIdentified() &&
"expected an indentified struct");
1205 if (failed(structType.trySetBody(
1206 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1207 deferredStructIt->memberDecorationsInfo,
1208 deferredStructIt->structDecorationsInfo)))
1211 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1222 if (operands.size() != 3) {
1224 "OpTypeArray must have element type and count parameters");
1229 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1237 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1238 << operands[2] <<
"can only come from normal constant right now";
1241 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1242 count = intVal.getValue().getZExtValue();
1244 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1245 "scalar integer constant instruction");
1249 elementTy, count, typeDecorations.lookup(operands[0]));
1255 assert(!operands.empty() &&
"No operands for processing function type");
1256 if (operands.size() == 1) {
1257 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1259 auto returnType =
getType(operands[1]);
1261 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1264 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1265 auto ty =
getType(operands[i]);
1267 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1269 argTypes.push_back(ty);
1275 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1281 if (operands.size() != 6) {
1283 "OpTypeCooperativeMatrixKHR must have element type, "
1284 "scope, row and column parameters, and use");
1290 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1294 std::optional<spirv::Scope> scope =
1299 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1308 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1309 "undefined constant <id> ")
1313 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1314 "references undefined constant <id> ")
1318 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1319 "undefined constant <id> ")
1322 unsigned rows = rowsAttr.getInt();
1323 unsigned columns = columnsAttr.getInt();
1325 std::optional<spirv::CooperativeMatrixUseKHR> use =
1326 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1330 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1334 typeMap[operands[0]] =
1341 if (operands.size() != 2) {
1342 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1347 "OpTypeRuntimeArray references undefined <id> ")
1351 memberType, typeDecorations.lookup(operands[0]));
1359 if (operands.empty()) {
1360 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1363 if (operands.size() == 1) {
1365 typeMap[operands[0]] =
1374 for (
auto op : llvm::drop_begin(operands, 1)) {
1376 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1378 if (!memberType && !typeForwardPtr)
1379 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1383 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1385 memberTypes.push_back(memberType);
1390 if (memberDecorationMap.count(operands[0])) {
1391 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1392 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1393 if (allMemberDecorations.count(memberIndex)) {
1394 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1396 if (memberDecoration.first == spirv::Decoration::Offset) {
1398 if (offsetInfo.empty()) {
1399 offsetInfo.resize(memberTypes.size());
1401 offsetInfo[memberIndex] = memberDecoration.second[0];
1403 auto intType = mlir::IntegerType::get(context, 32);
1404 if (!memberDecoration.second.empty()) {
1405 memberDecorationsInfo.emplace_back(
1406 memberIndex, memberDecoration.first,
1407 IntegerAttr::get(intType, memberDecoration.second[0]));
1409 memberDecorationsInfo.emplace_back(
1410 memberIndex, memberDecoration.first, UnitAttr::get(context));
1419 if (decorations.count(operands[0])) {
1422 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1423 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(),
true));
1424 assert(decoration.has_value());
1425 structDecorationsInfo.emplace_back(decoration.value(),
1426 decorationAttr.getValue());
1430 uint32_t structID = operands[0];
1431 std::string structIdentifier = nameMap.lookup(structID).str();
1433 if (structIdentifier.empty()) {
1434 assert(unresolvedMemberTypes.empty() &&
1435 "didn't expect unresolved member types");
1437 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1440 typeMap[structID] = structTy;
1442 if (!unresolvedMemberTypes.empty())
1443 deferredStructTypesInfos.push_back(
1444 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1445 memberDecorationsInfo, structDecorationsInfo});
1446 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1447 memberDecorationsInfo,
1448 structDecorationsInfo)))
1459 if (operands.size() != 3) {
1461 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1462 " (result_id, column_type, and column_count)");
1468 "OpTypeMatrix references undefined column type.")
1472 uint32_t colsCount = operands[2];
1479 unsigned size = operands.size();
1480 if (size < 2 || size > 4)
1481 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1482 "(result_id, element_type, (rank), (shape)) ")
1488 "OpTypeTensorARM references undefined element type ")
1498 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1499 "scalar integer constant instruction");
1500 unsigned rank = rankAttr.getValue().getZExtValue();
1507 std::optional<std::pair<Attribute, Type>> shapeInfo =
1510 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1511 "constant instruction of type OpTypeArray");
1513 ArrayAttr shapeArrayAttr = llvm::dyn_cast<ArrayAttr>(shapeInfo->first);
1515 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1516 auto dimIntAttr = llvm::dyn_cast<IntegerAttr>(dimAttr);
1518 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1520 shape.push_back(dimIntAttr.getValue().getSExtValue());
1528 unsigned size = operands.size();
1530 return emitError(unknownLoc,
"OpTypeGraphARM must have at least 2 operands "
1531 "(result_id, num_inputs, (inout0_type, "
1532 "inout1_type, ...))")
1535 uint32_t numInputs = operands[1];
1538 for (
unsigned i = 2; i < size; ++i) {
1542 "OpTypeGraphARM references undefined element type.")
1545 if (i - 2 >= numInputs) {
1546 returnTypes.push_back(inOutTy);
1548 argTypes.push_back(inOutTy);
1551 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1557 if (operands.size() != 2)
1559 "OpTypeForwardPointer instruction must have two operands");
1561 typeForwardPointerIDs.insert(operands[0]);
1571 if (operands.size() != 8)
1574 "OpTypeImage with non-eight operands are not supported yet");
1578 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1581 auto dim = spirv::symbolizeDim(operands[2]);
1583 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1586 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1588 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1591 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1593 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1596 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1598 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1600 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1601 if (!samplerUseInfo)
1602 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1605 auto format = spirv::symbolizeImageFormat(operands[7]);
1607 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1611 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1612 samplingInfo.value(), samplerUseInfo.value(), format.value());
1618 if (operands.size() != 2)
1619 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1624 "OpTypeSampledImage references undefined <id>: ")
1637 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1639 if (operands.size() < 2) {
1641 << opname <<
" must have type <id> and result <id>";
1643 if (operands.size() < 3) {
1645 << opname <<
" must have at least 1 more parameter";
1650 return emitError(unknownLoc,
"undefined result type from <id> ")
1654 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1655 if (bitwidth == 64) {
1656 if (operands.size() == 4) {
1660 << opname <<
" should have 2 parameters for 64-bit values";
1662 if (bitwidth <= 32) {
1663 if (operands.size() == 3) {
1669 <<
" should have 1 parameter for values with no more than 32 bits";
1671 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1675 auto resultID = operands[1];
1677 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1678 auto bitwidth = intType.getWidth();
1679 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1684 if (bitwidth == 64) {
1691 } words = {operands[2], operands[3]};
1692 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1693 }
else if (bitwidth <= 32) {
1694 value = APInt(bitwidth, operands[2],
true,
1698 auto attr = opBuilder.getIntegerAttr(intType, value);
1705 constantMap.try_emplace(resultID, attr, intType);
1711 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1712 auto bitwidth = floatType.getWidth();
1713 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1718 if (floatType.isF64()) {
1725 } words = {operands[2], operands[3]};
1726 value = APFloat(llvm::bit_cast<double>(words));
1727 }
else if (floatType.isF32()) {
1728 value = APFloat(llvm::bit_cast<float>(operands[2]));
1729 }
else if (floatType.isF16()) {
1730 APInt data(16, operands[2]);
1731 value = APFloat(APFloat::IEEEhalf(), data);
1732 }
else if (floatType.isBF16()) {
1733 APInt data(16, operands[2]);
1734 value = APFloat(APFloat::BFloat(), data);
1737 auto attr = opBuilder.getFloatAttr(floatType, value);
1743 constantMap.try_emplace(resultID, attr, floatType);
1749 return emitError(unknownLoc,
"OpConstant can only generate values of "
1750 "scalar integer or floating-point type");
1755 if (operands.size() != 2) {
1757 << (isSpec ?
"Spec" :
"") <<
"Constant"
1758 << (isTrue ?
"True" :
"False")
1759 <<
" must have type <id> and result <id>";
1762 auto attr = opBuilder.getBoolAttr(isTrue);
1763 auto resultID = operands[1];
1769 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1777 if (operands.size() < 2) {
1779 "OpConstantComposite must have type <id> and result <id>");
1781 if (operands.size() < 3) {
1783 "OpConstantComposite must have at least 1 parameter");
1788 return emitError(unknownLoc,
"undefined result type from <id> ")
1793 elements.reserve(operands.size() - 2);
1794 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1797 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1798 << operands[i] <<
" must come from a normal constant";
1800 elements.push_back(elementInfo->first);
1803 auto resultID = operands[1];
1804 if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1807 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1808 for (
auto value : denseElemAttr.getValues<
Attribute>())
1809 flattenedElems.push_back(value);
1811 flattenedElems.push_back(element);
1815 constantMap.try_emplace(resultID, attr, tensorType);
1816 }
else if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1820 constantMap.try_emplace(resultID, attr, shapedType);
1821 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1822 auto attr = opBuilder.getArrayAttr(elements);
1823 constantMap.try_emplace(resultID, attr, resultType);
1825 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1834 if (operands.size() != 3) {
1837 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1843 return emitError(unknownLoc,
"undefined result type from <id> ")
1847 auto compositeType = dyn_cast<CompositeType>(resultType);
1848 if (!compositeType) {
1850 "result type from <id> is not a composite type")
1854 uint32_t resultID = operands[1];
1855 uint32_t constantID = operands[2];
1857 std::optional<std::pair<Attribute, Type>> constantInfo =
1859 if (constantInfo.has_value()) {
1860 constantCompositeReplicateMap.try_emplace(
1861 resultID, constantInfo.value().first, resultType);
1865 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1867 if (replicatedConstantCompositeInfo.has_value()) {
1868 constantCompositeReplicateMap.try_emplace(
1869 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1873 return emitError(unknownLoc,
"OpConstantCompositeReplicateEXT operand <id> ")
1875 <<
" must come from a normal constant or a "
1876 "OpConstantCompositeReplicateEXT";
1881 if (operands.size() < 2) {
1884 "OpSpecConstantComposite must have type <id> and result <id>");
1886 if (operands.size() < 3) {
1888 "OpSpecConstantComposite must have at least 1 parameter");
1893 return emitError(unknownLoc,
"undefined result type from <id> ")
1897 auto resultID = operands[1];
1901 elements.reserve(operands.size() - 2);
1902 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1904 elements.push_back(SymbolRefAttr::get(elementInfo));
1907 auto op = spirv::SpecConstantCompositeOp::create(
1908 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1909 opBuilder.getArrayAttr(elements));
1910 specConstCompositeMap[resultID] = op;
1917 if (operands.size() != 3) {
1918 return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT expects "
1919 "3 operands but found ")
1925 return emitError(unknownLoc,
"undefined result type from <id> ")
1929 auto compositeType = dyn_cast<CompositeType>(resultType);
1930 if (!compositeType) {
1932 "result type from <id> is not a composite type")
1936 uint32_t resultID = operands[1];
1939 spirv::SpecConstantOp constituentSpecConstantOp =
1941 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1942 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1943 SymbolRefAttr::get(constituentSpecConstantOp));
1945 specConstCompositeReplicateMap[resultID] = op;
1952 if (operands.size() < 3)
1953 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1954 "result <id>, and operand opcode");
1956 uint32_t resultTypeID = operands[0];
1959 return emitError(unknownLoc,
"undefined result type from <id> ")
1962 uint32_t resultID = operands[1];
1963 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1964 auto emplaceResult = specConstOperationMap.try_emplace(
1967 enclosedOpcode, resultTypeID,
1970 if (!emplaceResult.second)
1971 return emitError(unknownLoc,
"value with <id>: ")
1972 << resultID <<
" is probably defined before.";
1978 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1994 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1995 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
1998 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
1999 enclosedOpResultTypeAndOperands.push_back(fakeID);
2000 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2001 enclosedOpOperands.end());
2016 auto specConstOperationOp =
2017 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2019 Region &body = specConstOperationOp.getBody();
2021 body.
getBlocks().splice(body.
end(), curBlock->getParent()->getBlocks(),
2028 opBuilder.setInsertionPointToEnd(&block);
2030 spirv::YieldOp::create(opBuilder, loc, block.
front().
getResult(0));
2031 return specConstOperationOp.getResult();
2036 if (operands.size() != 2) {
2038 "OpConstantNull must only have type <id> and result <id>");
2043 return emitError(unknownLoc,
"undefined result type from <id> ")
2047 auto resultID = operands[1];
2049 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
2050 attr = opBuilder.getZeroAttr(resultType);
2051 }
else if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2052 if (
auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2059 constantMap.try_emplace(resultID, attr, resultType);
2063 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
2069 if (operands.size() < 3) {
2071 <<
"OpGraphConstantARM must have at least 2 operands";
2076 return emitError(unknownLoc,
"undefined result type from <id> ")
2080 uint32_t resultID = operands[1];
2082 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2083 return emitError(unknownLoc,
"result must be of type OpTypeTensorARM");
2086 APInt graph_constant_id = APInt(32, operands[2],
true);
2087 Type i32Ty = opBuilder.getIntegerType(32);
2088 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2089 graphConstantMap.try_emplace(
2101 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
2102 <<
" @ " << block <<
"\n");
2109 auto *block = curFunction->addBlock();
2110 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
2111 <<
" @ " << block <<
"\n");
2112 return blockMap[id] = block;
2117 return emitError(unknownLoc,
"OpBranch must appear inside a block");
2120 if (operands.size() != 1) {
2121 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
2129 spirv::BranchOp::create(opBuilder, loc,
target);
2139 "OpBranchConditional must appear inside a block");
2142 if (operands.size() != 3 && operands.size() != 5) {
2144 "OpBranchConditional must have condition, true label, "
2145 "false label, and optionally two branch weights");
2148 auto condition =
getValue(operands[0]);
2152 std::optional<std::pair<uint32_t, uint32_t>> weights;
2153 if (operands.size() == 5) {
2154 weights = std::make_pair(operands[3], operands[4]);
2160 spirv::BranchConditionalOp::create(
2161 opBuilder, loc, condition, trueBlock,
2171 return emitError(unknownLoc,
"OpLabel must appear inside a function");
2174 if (operands.size() != 1) {
2175 return emitError(unknownLoc,
"OpLabel should only have result <id>");
2178 auto labelID = operands[0];
2181 LLVM_DEBUG(logger.startLine()
2182 <<
"[block] populating block " << block <<
"\n");
2184 assert(block->empty() &&
"re-deserialize the same block!");
2186 opBuilder.setInsertionPointToStart(block);
2187 blockMap[labelID] = curBlock = block;
2194 return emitError(unknownLoc,
"a graph block must appear inside a graph");
2199 LLVM_DEBUG(logger.startLine()
2200 <<
"[block] populating block " << block <<
"\n");
2202 assert(block->
empty() &&
"re-deserialize the same block!");
2204 opBuilder.setInsertionPointToStart(block);
2205 blockMap[graphID] = curBlock = block;
2213 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
2216 if (operands.size() < 2) {
2219 "OpSelectionMerge must specify merge target and selection control");
2224 auto selectionControl = operands[1];
2226 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2230 "a block cannot have more than one OpSelectionMerge instruction");
2239 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
2242 if (operands.size() < 3) {
2243 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
2244 "continue target and loop control");
2250 uint32_t loopControl = operands[2];
2253 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2257 "a block cannot have more than one OpLoopMerge instruction");
2265 return emitError(unknownLoc,
"OpPhi must appear in a block");
2268 if (operands.size() < 4) {
2269 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
2270 "and variable-parent pairs");
2275 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2276 valueMap[operands[1]] = blockArg;
2277 LLVM_DEBUG(logger.startLine()
2278 <<
"[phi] created block argument " << blockArg
2279 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
2283 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2284 uint32_t value = operands[i];
2286 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2287 blockPhiInfo[predecessorTargetPair].push_back(value);
2288 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
2289 <<
" with arg id = " << value <<
"\n");
2297 return emitError(unknownLoc,
"OpSwitch must appear in a block");
2299 if (operands.size() < 2)
2300 return emitError(unknownLoc,
"OpSwitch must at least specify selector and "
2301 "a default target");
2303 if (operands.size() % 2)
2305 "OpSwitch must at have an even number of operands: "
2306 "selector, default target and any number of literal and "
2307 "label <id> pairs");
2315 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2316 literals.push_back(operands[i]);
2321 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2330class ControlFlowStructurizer {
2333 ControlFlowStructurizer(
Location loc, uint32_t control,
2336 llvm::ScopedPrinter &logger)
2337 : location(loc), control(control), blockMergeInfo(mergeInfo),
2338 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2341 ControlFlowStructurizer(
Location loc, uint32_t control,
2344 : location(loc), control(control), blockMergeInfo(mergeInfo),
2345 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2355 LogicalResult structurize();
2360 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2363 spirv::LoopOp createLoopOp(uint32_t loopControl);
2366 void collectBlocksInConstruct();
2375 Block *continueBlock;
2381 llvm::ScopedPrinter &logger;
2387ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2390 OpBuilder builder(&mergeBlock->front());
2392 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
2393 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2394 selectionOp.addMergeBlock(builder);
2399spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2402 OpBuilder builder(&mergeBlock->front());
2404 auto control =
static_cast<spirv::LoopControl
>(loopControl);
2405 auto loopOp = spirv::LoopOp::create(builder, location, control);
2406 loopOp.addEntryAndMergeBlock(builder);
2411void ControlFlowStructurizer::collectBlocksInConstruct() {
2412 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
2415 constructBlocks.insert(headerBlock);
2419 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
2420 for (
auto *successor : constructBlocks[i]->getSuccessors())
2421 if (successor != mergeBlock)
2422 constructBlocks.insert(successor);
2426LogicalResult ControlFlowStructurizer::structurize() {
2427 Operation *op =
nullptr;
2428 bool isLoop = continueBlock !=
nullptr;
2430 if (
auto loopOp = createLoopOp(control))
2431 op = loopOp.getOperation();
2433 if (
auto selectionOp = createSelectionOp(control))
2434 op = selectionOp.getOperation();
2443 mapper.
map(mergeBlock, &body.
back());
2445 collectBlocksInConstruct();
2466 OpBuilder builder(body);
2467 for (
auto *block : constructBlocks) {
2470 auto *newBlock = builder.createBlock(&body.
back());
2471 mapper.
map(block, newBlock);
2472 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2473 <<
" from block " << block <<
"\n");
2475 for (BlockArgument blockArg : block->getArguments()) {
2477 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2478 mapper.
map(blockArg, newArg);
2479 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2480 << blockArg <<
" to " << newArg <<
"\n");
2483 LLVM_DEBUG(logger.startLine()
2484 <<
"[cf] block " << block <<
" is a function entry block\n");
2487 for (
auto &op : *block)
2488 newBlock->push_back(op.
clone(mapper));
2492 auto remapOperands = [&](Operation *op) {
2494 if (Value mappedOp = mapper.
lookupOrNull(operand.get()))
2495 operand.set(mappedOp);
2498 succOp.set(mappedOp);
2500 for (
auto &block : body)
2501 block.walk(remapOperands);
2509 headerBlock->replaceAllUsesWith(mergeBlock);
2512 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2513 headerBlock->getParentOp()->print(logger.getOStream());
2514 logger.startLine() <<
"\n";
2518 if (!mergeBlock->args_empty()) {
2519 return mergeBlock->getParentOp()->emitError(
2520 "OpPhi in loop merge block unsupported");
2526 for (BlockArgument blockArg : headerBlock->getArguments())
2527 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2531 SmallVector<Value, 4> blockArgs;
2532 if (!headerBlock->args_empty())
2533 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2537 builder.setInsertionPointToEnd(&body.front());
2538 spirv::BranchOp::create(builder, location, mapper.
lookupOrNull(headerBlock),
2539 ArrayRef<Value>(blockArgs));
2544 SmallVector<Value> valuesToYield;
2547 SmallVector<Value> outsideUses;
2561 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2566 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2567 valuesToYield.push_back(body.back().getArguments().back());
2568 outsideUses.push_back(blockArg);
2573 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2576 for (
auto *block : constructBlocks)
2577 block->dropAllReferences();
2582 for (
Block *block : constructBlocks) {
2583 for (Operation &op : *block) {
2587 outsideUses.push_back(
result);
2590 for (BlockArgument &arg : block->getArguments()) {
2591 if (!arg.use_empty()) {
2593 outsideUses.push_back(arg);
2598 assert(valuesToYield.size() == outsideUses.size());
2602 if (!valuesToYield.empty()) {
2603 LLVM_DEBUG(logger.startLine()
2604 <<
"[cf] yielding values from the selection / loop region\n");
2607 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2608 Operation *merge = llvm::getSingleElement(mergeOps);
2610 merge->setOperands(valuesToYield);
2618 builder.setInsertionPoint(&mergeBlock->front());
2620 Operation *newOp =
nullptr;
2623 newOp = spirv::LoopOp::create(builder, location,
2625 static_cast<spirv::LoopControl
>(control));
2627 newOp = spirv::SelectionOp::create(
2629 static_cast<spirv::SelectionControl
>(control));
2639 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2640 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2646 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2653 for (
auto *block : constructBlocks) {
2654 if (!block->use_empty())
2655 return emitError(block->getParent()->getLoc(),
2656 "failed control flow structurization: "
2657 "block has uses outside of the "
2658 "enclosing selection/loop construct");
2659 for (Operation &op : *block)
2661 return op.
emitOpError(
"failed control flow structurization: value has "
2662 "uses outside of the "
2663 "enclosing selection/loop construct");
2664 for (BlockArgument &arg : block->getArguments())
2665 if (!arg.use_empty())
2666 return emitError(arg.getLoc(),
"failed control flow structurization: "
2667 "block argument has uses outside of the "
2668 "enclosing selection/loop construct");
2672 for (
auto *block : constructBlocks) {
2712 auto updateMergeInfo = [&](
Block *block) -> WalkResult {
2713 auto it = blockMergeInfo.find(block);
2714 if (it != blockMergeInfo.end()) {
2716 Location loc = it->second.loc;
2720 return emitError(loc,
"failed control flow structurization: nested "
2721 "loop header block should be remapped!");
2723 Block *newContinue = it->second.continueBlock;
2727 return emitError(loc,
"failed control flow structurization: nested "
2728 "loop continue block should be remapped!");
2731 Block *newMerge = it->second.mergeBlock;
2733 newMerge = mappedTo;
2737 blockMergeInfo.
erase(it);
2738 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2745 if (block->walk(updateMergeInfo).wasInterrupted())
2753 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2754 <<
" to only contain a spirv.Branch op\n");
2758 builder.setInsertionPointToEnd(block);
2759 spirv::BranchOp::create(builder, location, mergeBlock);
2761 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2766 LLVM_DEBUG(logger.startLine()
2767 <<
"[cf] after structurizing construct with header block "
2768 << headerBlock <<
":\n"
2777 <<
"//----- [phi] start wiring up block arguments -----//\n";
2783 for (
const auto &info : blockPhiInfo) {
2784 Block *block = info.first.first;
2788 logger.startLine() <<
"[phi] block " << block <<
"\n";
2789 logger.startLine() <<
"[phi] before creating block argument:\n";
2791 logger.startLine() <<
"\n";
2797 opBuilder.setInsertionPoint(op);
2800 blockArgs.reserve(phiInfo.size());
2801 for (uint32_t valueId : phiInfo) {
2803 blockArgs.push_back(value);
2804 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2805 <<
" id = " << valueId <<
"\n");
2807 return emitError(unknownLoc,
"OpPhi references undefined value!");
2811 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2813 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2814 branchOp.getTarget(), blockArgs);
2816 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2817 assert((branchCondOp.getTrueBlock() ==
target ||
2818 branchCondOp.getFalseBlock() ==
target) &&
2819 "expected target to be either the true or false target");
2820 if (
target == branchCondOp.getTrueTarget())
2821 spirv::BranchConditionalOp::create(
2822 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2823 blockArgs, branchCondOp.getFalseBlockArguments(),
2824 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2825 branchCondOp.getFalseTarget());
2827 spirv::BranchConditionalOp::create(
2828 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2829 branchCondOp.getTrueBlockArguments(), blockArgs,
2830 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2831 branchCondOp.getFalseBlock());
2833 branchCondOp.erase();
2835 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2839 logger.startLine() <<
"[phi] after creating block argument:\n";
2841 logger.startLine() <<
"\n";
2844 blockPhiInfo.clear();
2849 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2857 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2859 auto &[block, mergeInfo] = *it;
2862 if (mergeInfo.continueBlock)
2865 if (!block->mightHaveTerminator())
2868 Operation *terminator = block->getTerminator();
2871 if (!isa<spirv::BranchConditionalOp>(terminator))
2875 bool splitHeaderMergeBlock =
false;
2876 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2877 if (mergeInfo.mergeBlock == block)
2878 splitHeaderMergeBlock =
true;
2885 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2888 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2892 blockMergeInfo.erase(block);
2893 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2901 if (!options.enableControlFlowStructurization) {
2905 <<
"//----- [cf] skip structurizing control flow -----//\n";
2913 <<
"//----- [cf] start structurizing control flow -----//\n";
2918 logger.startLine() <<
"[cf] split conditional blocks\n";
2919 logger.startLine() <<
"\n";
2929 while (!blockMergeInfo.empty()) {
2930 Block *headerBlock = blockMergeInfo.
begin()->first;
2934 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2935 headerBlock->
print(logger.getOStream());
2936 logger.startLine() <<
"\n";
2940 assert(mergeBlock &&
"merge block cannot be nullptr");
2942 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2944 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2945 mergeBlock->print(logger.getOStream());
2946 logger.startLine() <<
"\n";
2950 LLVM_DEBUG(
if (continueBlock) {
2951 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2952 continueBlock->print(logger.getOStream());
2953 logger.startLine() <<
"\n";
2957 blockMergeInfo.
erase(blockMergeInfo.begin());
2958 ControlFlowStructurizer structurizer(mergeInfo.
loc, mergeInfo.
control,
2959 blockMergeInfo, headerBlock,
2960 mergeBlock, continueBlock
2966 if (failed(structurizer.structurize()))
2973 <<
"//--- [cf] completed structurizing control flow ---//\n";
2986 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
2987 if (fileName.empty())
2988 fileName =
"<unknown>";
3000 if (operands.size() != 3)
3001 return emitError(unknownLoc,
"OpLine must have 3 operands");
3002 debugLine =
DebugLine{operands[0], operands[1], operands[2]};
3010 if (operands.size() < 2)
3011 return emitError(unknownLoc,
"OpString needs at least 2 operands");
3013 if (!debugInfoMap.lookup(operands[0]).empty())
3015 "duplicate debug string found for result <id> ")
3018 unsigned wordIndex = 1;
3020 if (wordIndex != operands.size())
3022 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< BlockOperand > getBlockOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult splitConditionalBlocks()
Move a conditional branch into a separate basic block to avoid unnecessary sinking of defs that may b...
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
DenseMap< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.