23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
35#define DEBUG_TYPE "spirv-deserialization"
44 isa_and_nonnull<spirv::FuncOp>(block->
getParentOp());
54 : binary(binary), context(context), unknownLoc(UnknownLoc::
get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()),
options(
options)
63LogicalResult spirv::Deserializer::deserialize() {
67 <<
"//+++---------- start deserialization ----------+++//\n";
70 if (
failed(processHeader()))
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
89 for (
auto &deferred : deferredInstructions) {
97 LLVM_DEBUG(logger.startLine()
98 <<
"//+++-------- completed deserialization --------+++//\n");
// Returns the deserialized spirv::ModuleOp to the caller by moving it out of
// the `module` member, transferring ownership of the OwningOpRef.
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
// Prepares the OperationState for a spirv::ModuleOp located at `unknownLoc`,
// using a fresh OpBuilder over `context`, and populates it via
// spirv::ModuleOp::build with no extra arguments.
// NOTE(review): the line that materializes and returns the op is not visible
// in this extract.
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
// Validates the SPIR-V module header. It:
//  - requires the 5-word header,
//  - checks the magic number,
//  - decodes the version word (binary[1]) into `version`.
// Only major version 1 is accepted. The minor version is mapped to
// spirv::Version::V_1_<minor> through MIN_VERSION_CASE. Any other major or
// minor version is diagnosed as unsupported.
117LogicalResult spirv::Deserializer::processHeader() {
120 "SPIR-V binary module must have a 5-word header");
123 return emitError(unknownLoc,
"incorrect magic number");
// Shifting left then right by 24 isolates one byte of the version word:
// bits 16-23 give the major version and bits 8-15 the minor version.
// This assumes binary[1] is an unsigned 32-bit word — TODO confirm.
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
132 version = spirv::Version::V_1_##v; \
142#undef MIN_VERSION_CASE
144 return emitError(unknownLoc,
"unsupported SPIR-V minor version: ")
148 return emitError(unknownLoc,
"unsupported SPIR-V major version: ")
// Handles OpCapability. It:
//  - requires exactly one operand,
//  - maps it to a spirv::Capability via symbolizeCapability,
//  - diagnoses unknown values,
//  - records the capability in `capabilities`.
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc,
"OpCapability must have one parameter");
162 auto cap = spirv::symbolizeCapability(operands[0]);
164 return emitError(unknownLoc,
"unknown capability: ") << operands[0];
166 capabilities.insert(*cap);
// Handles OpExtension. The instruction's words hold a literal string naming
// the extension, decoded starting at word 0. Trailing words after the string
// are rejected. The name is mapped via symbolizeExtension: unknown names are
// diagnosed, and known ones are recorded in `extensions`.
// NOTE(review): the string-decoding and empty-input check lines are not
// visible in this extract.
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
174 "OpExtension must have a literal string for the extension name");
177 unsigned wordIndex = 0;
179 if (wordIndex != words.size())
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
184 return emitError(unknownLoc,
"unknown extension: ") << extName;
186 extensions.insert(*ext);
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
198 unsigned wordIndex = 1;
200 if (wordIndex != words.size()) {
202 "unexpected trailing words in OpExtInstImport");
207void spirv::Deserializer::attachVCETriple() {
209 spirv::ModuleOp::getVCETripleAttrName(),
211 extensions.getArrayRef(), context));
// Handles OpMemoryModel, which requires exactly two operands:
//  - the first is cast to spirv::AddressingModel and wrapped in an
//    AddressingModelAttr stored under the module's addressing-model
//    attribute name (the opening setAttr line is not visible here);
//  - the second is cast to spirv::MemoryModel and set on the module as a
//    MemoryModelAttr.
// NOTE(review): the raw words are static_cast without validation, so
// out-of-range enum values are not rejected here — confirm whether this is
// checked elsewhere.
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc,
"OpMemoryModel must have two operands");
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel
>(operands.front())));
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel
>(operands.back())));
// Parses a cache-control decoration, which must be exactly 4 words.
// Word layout:
//  - words[0]: the decorated <id>, used as the key into `decorations`;
//  - words[2]: the cache level;
//  - words[3]: the cache-control value, static_cast to EnumTy.
// The level and value are combined into an AttrTy. If an ArrayAttr is already
// stored under `symbol` for that <id>, its elements are kept and the new
// value is appended, so repeated decorations accumulate rather than overwrite.
// NOTE(review): the diagnostic below streams decorationName directly
// followed by "needs a cache control ..." with no separating space, so the
// decoration name runs into the word "needs".
231template <
typename AttrTy,
typename EnumAttrTy,
typename EnumTy>
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc,
"OpDecoration with ")
238 << decorationName <<
"needs a cache control integer literal and a "
239 << cacheControlKind <<
" cache control literal";
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr =
static_cast<EnumTy
>(words[3]);
243 auto value = opBuilder.
getAttr<AttrTy>(cacheLevel, cacheControlAttr);
246 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].
get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.
getArrayAttr(attrs));
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
257 if (words.size() < 2) {
259 unknownLoc,
"OpDecorate must have at least result <id> and Decoration");
261 auto decorationName =
262 stringifyDecoration(
static_cast<spirv::Decoration
>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc,
"invalid Decoration code : ") << words[1];
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (
static_cast<spirv::Decoration
>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc,
"OpDecorate with ")
271 << decorationName <<
" needs a single integer literal";
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode
>(words[2])));
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc,
"OpDecorate with ")
280 << decorationName <<
" needs a single integer literal";
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode
>(words[2])));
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 if (words.size() != 3) {
289 return emitError(unknownLoc,
"OpDecorate with ")
290 << decorationName <<
" needs a single integer literal";
292 decorations[words[0]].set(
293 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
295 case spirv::Decoration::BuiltIn:
296 if (words.size() != 3) {
297 return emitError(unknownLoc,
"OpDecorate with ")
298 << decorationName <<
" needs a single integer literal";
300 decorations[words[0]].set(
301 symbol, opBuilder.getStringAttr(
302 stringifyBuiltIn(
static_cast<spirv::BuiltIn
>(words[2]))));
304 case spirv::Decoration::ArrayStride:
305 if (words.size() != 3) {
306 return emitError(unknownLoc,
"OpDecorate with ")
307 << decorationName <<
" needs a single integer literal";
309 typeDecorations[words[0]] = words[2];
311 case spirv::Decoration::LinkageAttributes: {
312 if (words.size() < 4) {
313 return emitError(unknownLoc,
"OpDecorate with ")
315 <<
" needs at least 1 string and 1 integer literal";
323 unsigned wordIndex = 2;
325 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
326 static_cast<::mlir::spirv::LinkageType
>(words[wordIndex++]));
327 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328 StringAttr::get(context, linkageName), linkageTypeAttr);
329 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
332 case spirv::Decoration::Aliased:
333 case spirv::Decoration::AliasedPointer:
334 case spirv::Decoration::Block:
335 case spirv::Decoration::BufferBlock:
336 case spirv::Decoration::Flat:
337 case spirv::Decoration::NonReadable:
338 case spirv::Decoration::NonWritable:
339 case spirv::Decoration::NoPerspective:
340 case spirv::Decoration::NoSignedWrap:
341 case spirv::Decoration::NoUnsignedWrap:
342 case spirv::Decoration::RelaxedPrecision:
343 case spirv::Decoration::Restrict:
344 case spirv::Decoration::RestrictPointer:
345 case spirv::Decoration::NoContraction:
346 case spirv::Decoration::Constant:
347 case spirv::Decoration::Invariant:
348 case spirv::Decoration::Patch:
349 case spirv::Decoration::Coherent:
350 if (words.size() != 2) {
351 return emitError(unknownLoc,
"OpDecoration with ")
352 << decorationName <<
"needs a single target <id>";
354 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
356 case spirv::Decoration::Location:
357 case spirv::Decoration::SpecId:
358 case spirv::Decoration::Index:
359 if (words.size() != 3) {
360 return emitError(unknownLoc,
"OpDecoration with ")
361 << decorationName <<
"needs a single integer literal";
363 decorations[words[0]].set(
364 symbol, opBuilder.getI32IntegerAttr(
static_cast<int32_t
>(words[2])));
366 case spirv::Decoration::CacheControlLoadINTEL: {
368 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
369 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
375 case spirv::Decoration::CacheControlStoreINTEL: {
377 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
378 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
385 return emitError(unknownLoc,
"unhandled Decoration : '") << decorationName;
// Handles OpMemberDecorate. Word layout:
//  - words[0]: the struct type <id>;
//  - words[1]: the member index;
//  - words[2]: the decoration;
//  - remaining words: the decoration's literal operands.
// An Offset decoration must carry exactly one literal (4 words total). The
// operands are recorded in memberDecorationMap[type][member][decoration] and
// consumed later when the OpTypeStruct is processed.
// NOTE(review): the stored ArrayRef aliases the caller's word buffer. This
// presumably relies on the module binary outliving struct processing —
// confirm.
391spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
393 if (words.size() < 3) {
395 "OpMemberDecorate must have at least 3 operands");
398 auto decoration =
static_cast<spirv::Decoration
>(words[2]);
399 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
401 " missing offset specification in OpMemberDecorate with "
402 "Offset decoration");
404 ArrayRef<uint32_t> decorationOperands;
405 if (words.size() > 3) {
406 decorationOperands = words.slice(3);
408 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
// Handles OpMemberName. Word layout:
//  - words[0]: the type <id>;
//  - words[1]: the member index;
//  - from word 2 onward: the literal name string (its decoding line is not
//    visible in this extract).
// Trailing words after the string are rejected, and the name is stored in
// memberNameMap[type][member].
412LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
413 if (words.size() < 3) {
414 return emitError(unknownLoc,
"OpMemberName must have at least 3 operands");
416 unsigned wordIndex = 2;
418 if (wordIndex != words.size()) {
420 "unexpected trailing words in OpMemberName instruction");
422 memberNameMap[words[0]][words[1]] = name;
428 if (!decorations.contains(argID)) {
429 argAttrs[argIndex] = DictionaryAttr::get(context, {});
433 spirv::DecorationAttr foundDecorationAttr;
435 for (
auto decoration :
436 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
437 spirv::Decoration::AliasedPointer,
438 spirv::Decoration::RestrictPointer}) {
440 if (decAttr.getName() !=
444 if (foundDecorationAttr)
446 "more than one Aliased/Restrict decorations for "
447 "function argument with result <id> ")
450 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
455 spirv::Decoration::RelaxedPrecision))) {
460 if (foundDecorationAttr)
461 return emitError(unknownLoc,
"already found a decoration for function "
462 "argument with result <id> ")
465 foundDecorationAttr = spirv::DecorationAttr::get(
466 context, spirv::Decoration::RelaxedPrecision);
470 if (!foundDecorationAttr)
471 return emitError(unknownLoc,
"unimplemented decoration support for "
472 "function argument with result <id> ")
475 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
476 foundDecorationAttr);
477 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
484 return emitError(unknownLoc,
"found function inside function");
488 if (operands.size() != 4) {
489 return emitError(unknownLoc,
"OpFunction must have 4 parameters");
493 return emitError(unknownLoc,
"undefined result type from <id> ")
497 uint32_t fnID = operands[1];
498 if (funcMap.count(fnID)) {
499 return emitError(unknownLoc,
"duplicate function definition/declaration");
502 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
504 return emitError(unknownLoc,
"unknown Function Control: ") << operands[2];
508 if (!fnType || !isa<FunctionType>(fnType)) {
509 return emitError(unknownLoc,
"unknown function type from <id> ")
512 auto functionType = cast<FunctionType>(fnType);
514 if ((
isVoidType(resultType) && functionType.getNumResults() != 0) ||
515 (functionType.getNumResults() == 1 &&
516 functionType.getResult(0) != resultType)) {
517 return emitError(unknownLoc,
"mismatch in function type ")
518 << functionType <<
" and return type " << resultType <<
" specified";
522 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
523 functionType, fnControl.value());
525 if (decorations.count(fnID)) {
526 for (
auto attr : decorations[fnID].getAttrs()) {
527 funcOp->setAttr(attr.getName(), attr.getValue());
530 curFunction = funcMap[fnID] = funcOp;
531 auto *entryBlock = funcOp.addEntryBlock();
534 <<
"//===-------------------------------------------===//\n";
535 logger.startLine() <<
"[fn] name: " << fnName <<
"\n";
536 logger.startLine() <<
"[fn] type: " << fnType <<
"\n";
537 logger.startLine() <<
"[fn] ID: " << fnID <<
"\n";
538 logger.startLine() <<
"[fn] entry block: " << entryBlock <<
"\n";
543 argAttrs.resize(functionType.getNumInputs());
546 if (functionType.getNumInputs()) {
547 for (
size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
548 auto argType = functionType.getInput(i);
549 spirv::Opcode opcode = spirv::Opcode::OpNop;
552 spirv::Opcode::OpFunctionParameter))) {
555 if (opcode != spirv::Opcode::OpFunctionParameter) {
558 "missing OpFunctionParameter instruction for argument ")
561 if (operands.size() != 2) {
564 "expected result type and result <id> for OpFunctionParameter");
566 auto argDefinedType =
getType(operands[0]);
567 if (!argDefinedType || argDefinedType != argType) {
569 "mismatch in argument type between function type "
571 << functionType <<
" and argument type definition "
572 << argDefinedType <<
" at argument " << i;
575 return emitError(unknownLoc,
"duplicate definition of result <id> ")
582 auto argValue = funcOp.getArgument(i);
583 valueMap[operands[1]] = argValue;
587 if (llvm::any_of(argAttrs, [](
Attribute attr) {
588 auto argAttr = cast<DictionaryAttr>(attr);
589 return !argAttr.empty();
591 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
596 auto linkageAttr = funcOp.getLinkageAttributes();
597 auto hasImportLinkage =
598 linkageAttr && (linkageAttr.value().getLinkageType().
getValue() ==
599 spirv::LinkageType::Import);
600 if (hasImportLinkage)
607 spirv::Opcode opcode = spirv::Opcode::OpNop;
616 spirv::Opcode::OpFunctionEnd))) {
619 if (opcode == spirv::Opcode::OpFunctionEnd) {
622 if (opcode != spirv::Opcode::OpLabel) {
623 return emitError(unknownLoc,
"a basic block must start with OpLabel");
625 if (instOperands.size() != 1) {
626 return emitError(unknownLoc,
"OpLabel should only have result <id>");
628 blockMap[instOperands[0]] = entryBlock;
636 spirv::Opcode::OpFunctionEnd)) &&
637 opcode != spirv::Opcode::OpFunctionEnd) {
642 if (opcode != spirv::Opcode::OpFunctionEnd) {
652 if (!operands.empty()) {
653 return emitError(unknownLoc,
"unexpected operands for OpFunctionEnd");
664 curFunction = std::nullopt;
669 <<
"//===-------------------------------------------===//\n";
676 if (operands.size() < 2) {
678 "missing graph defintion in OpGraphEntryPointARM");
681 unsigned wordIndex = 0;
682 uint32_t graphID = operands[wordIndex++];
683 if (!graphMap.contains(graphID)) {
685 "missing graph definition/declaration with id ")
689 spirv::GraphARMOp graphARM = graphMap[graphID];
691 graphARM.setSymName(name);
692 graphARM.setEntryPoint(
true);
695 for (
int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
697 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
699 return emitError(unknownLoc,
"undefined result <id> ")
700 << operands[wordIndex] <<
" while decoding OpGraphEntryPoint";
706 opBuilder.setInsertionPoint(graphARM);
707 spirv::GraphEntryPointARMOp::create(
708 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
709 opBuilder.getArrayAttr(interface));
717 return emitError(unknownLoc,
"found graph inside graph");
720 if (operands.size() < 2) {
721 return emitError(unknownLoc,
"OpGraphARM must have at least 2 parameters");
725 if (!type || !isa<GraphType>(type)) {
726 return emitError(unknownLoc,
"unknown graph type from <id> ")
729 auto graphType = cast<GraphType>(type);
730 if (graphType.getNumResults() <= 0) {
731 return emitError(unknownLoc,
"expected at least one result");
734 uint32_t graphID = operands[1];
735 if (graphMap.count(graphID)) {
736 return emitError(unknownLoc,
"duplicate graph definition/declaration");
741 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
742 curGraph = graphMap[graphID] = graphOp;
743 Block *entryBlock = graphOp.addEntryBlock();
746 <<
"//===-------------------------------------------===//\n";
747 logger.startLine() <<
"[graph] name: " << graphName <<
"\n";
748 logger.startLine() <<
"[graph] type: " << graphType <<
"\n";
749 logger.startLine() <<
"[graph] ID: " << graphID <<
"\n";
750 logger.startLine() <<
"[graph] entry block: " << entryBlock <<
"\n";
755 for (
auto [
index, argType] : llvm::enumerate(graphType.getInputs())) {
756 spirv::Opcode opcode;
759 spirv::Opcode::OpGraphInputARM))) {
762 if (operands.size() != 3) {
763 return emitError(unknownLoc,
"expected result type, result <id> and "
764 "input index for OpGraphInputARM");
768 if (!argDefinedType) {
769 return emitError(unknownLoc,
"unknown operand type <id> ") << operands[0];
772 if (argDefinedType != argType) {
774 "mismatch in argument type between graph type "
776 << graphType <<
" and argument type definition " << argDefinedType
777 <<
" at argument " <<
index;
780 return emitError(unknownLoc,
"duplicate definition of result <id> ")
785 if (!inputIndexAttr) {
787 "unable to read inputIndex value from constant op ")
790 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
791 valueMap[operands[1]] = argValue;
794 graphOutputs.resize(graphType.getNumResults());
800 blockMap[graphID] = entryBlock;
807 spirv::Opcode opcode;
817 }
while (opcode != spirv::Opcode::OpGraphEndARM);
824 if (operands.size() != 2) {
827 "expected value id and output index for OpGraphSetOutputARM");
830 uint32_t
id = operands[0];
833 return emitError(unknownLoc,
"could not find result <id> ") << id;
837 if (!outputIndexAttr) {
839 "unable to read outputIndex value from constant op ")
842 graphOutputs[outputIndexAttr.getInt()] = value;
849 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
852 if (!operands.empty()) {
853 return emitError(unknownLoc,
"unexpected operands for OpGraphEndARM");
857 curGraph = std::nullopt;
858 graphOutputs.clear();
863 <<
"//===-------------------------------------------===//\n";
868std::optional<std::pair<Attribute, Type>>
870 auto constIt = constantMap.find(
id);
871 if (constIt == constantMap.end())
873 return constIt->getSecond();
876std::optional<std::pair<Attribute, Type>>
878 if (
auto it = constantCompositeReplicateMap.find(
id);
879 it != constantCompositeReplicateMap.end())
884std::optional<spirv::SpecConstOperationMaterializationInfo>
886 auto constIt = specConstOperationMap.find(
id);
887 if (constIt == specConstOperationMap.end())
889 return constIt->getSecond();
893 auto funcName = nameMap.lookup(
id).str();
894 if (funcName.empty()) {
895 funcName =
"spirv_fn_" + std::to_string(
id);
901 std::string graphName = nameMap.lookup(
id).str();
902 if (graphName.empty()) {
903 graphName =
"spirv_graph_" + std::to_string(
id);
909 auto constName = nameMap.lookup(
id).str();
910 if (constName.empty()) {
911 constName =
"spirv_spec_const_" + std::to_string(
id);
918 TypedAttr defaultValue) {
920 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
922 if (decorations.count(resultID)) {
923 for (
auto attr : decorations[resultID].getAttrs())
924 op->setAttr(attr.getName(), attr.getValue());
926 specConstMap[resultID] = op;
930std::optional<spirv::GraphConstantARMOpMaterializationInfo>
932 auto graphConstIt = graphConstantMap.find(
id);
933 if (graphConstIt == graphConstantMap.end())
935 return graphConstIt->getSecond();
940 unsigned wordIndex = 0;
941 if (operands.size() < 3) {
944 "OpVariable needs at least 3 operands, type, <id> and storage class");
948 auto type =
getType(operands[wordIndex]);
950 return emitError(unknownLoc,
"unknown result type <id> : ")
951 << operands[wordIndex];
953 auto ptrType = dyn_cast<spirv::PointerType>(type);
956 "expected a result type <id> to be a spirv.ptr, found : ")
962 auto variableID = operands[wordIndex];
963 auto variableName = nameMap.lookup(variableID).str();
964 if (variableName.empty()) {
965 variableName =
"spirv_var_" + std::to_string(variableID);
970 auto storageClass =
static_cast<spirv::StorageClass
>(operands[wordIndex]);
971 if (ptrType.getStorageClass() != storageClass) {
972 return emitError(unknownLoc,
"mismatch in storage class of pointer type ")
973 << type <<
" and that specified in OpVariable instruction : "
974 << stringifyStorageClass(storageClass);
981 if (wordIndex < operands.size()) {
991 return emitError(unknownLoc,
"unknown <id> ")
992 << operands[wordIndex] <<
"used as initializer";
994 initializer = SymbolRefAttr::get(op);
997 if (wordIndex != operands.size()) {
999 "found more operands than expected when deserializing "
1000 "OpVariable instruction, only ")
1001 << wordIndex <<
" of " << operands.size() <<
" processed";
1004 auto varOp = spirv::GlobalVariableOp::create(
1005 opBuilder, loc, TypeAttr::get(type),
1006 opBuilder.getStringAttr(variableName), initializer);
1009 if (decorations.count(variableID)) {
1010 for (
auto attr : decorations[variableID].getAttrs())
1011 varOp->setAttr(attr.getName(), attr.getValue());
1013 globalVariableMap[variableID] = varOp;
1022 return dyn_cast<IntegerAttr>(constInfo->first);
1026 if (operands.size() < 2) {
1027 return emitError(unknownLoc,
"OpName needs at least 2 operands");
1029 if (!nameMap.lookup(operands[0]).empty()) {
1030 return emitError(unknownLoc,
"duplicate name found for result <id> ")
1033 unsigned wordIndex = 1;
1035 if (wordIndex != operands.size()) {
1037 "unexpected trailing words in OpName instruction");
1039 nameMap[operands[0]] = name;
1049 if (operands.empty()) {
1050 return emitError(unknownLoc,
"type instruction with opcode ")
1051 << spirv::stringifyOpcode(opcode) <<
" needs at least one <id>";
1056 if (typeMap.count(operands[0])) {
1057 return emitError(unknownLoc,
"duplicate definition for result <id> ")
1062 case spirv::Opcode::OpTypeVoid:
1063 if (operands.size() != 1)
1064 return emitError(unknownLoc,
"OpTypeVoid must have no parameters");
1065 typeMap[operands[0]] = opBuilder.getNoneType();
1067 case spirv::Opcode::OpTypeBool:
1068 if (operands.size() != 1)
1069 return emitError(unknownLoc,
"OpTypeBool must have no parameters");
1070 typeMap[operands[0]] = opBuilder.getI1Type();
1072 case spirv::Opcode::OpTypeInt: {
1073 if (operands.size() != 3)
1075 unknownLoc,
"OpTypeInt must have bitwidth and signedness parameters");
1084 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1085 : IntegerType::SignednessSemantics::Signless;
1086 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1088 case spirv::Opcode::OpTypeFloat: {
1089 if (operands.size() != 2 && operands.size() != 3)
1091 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1092 "or 3 operands (type, bitwidth, encoding), but got ")
1094 uint32_t bitWidth = operands[1];
1099 floatTy = opBuilder.getF16Type();
1102 floatTy = opBuilder.getF32Type();
1105 floatTy = opBuilder.getF64Type();
1108 return emitError(unknownLoc,
"unsupported OpTypeFloat bitwidth: ")
1112 if (operands.size() == 3) {
1113 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
1114 return emitError(unknownLoc,
"unsupported OpTypeFloat FP encoding: ")
1118 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
1119 << bitWidth <<
" (expected 16)";
1120 floatTy = opBuilder.getBF16Type();
1123 typeMap[operands[0]] = floatTy;
1125 case spirv::Opcode::OpTypeVector: {
1126 if (operands.size() != 3) {
1129 "OpTypeVector must have element type and count parameters");
1133 return emitError(unknownLoc,
"OpTypeVector references undefined <id> ")
1136 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1138 case spirv::Opcode::OpTypePointer: {
1141 case spirv::Opcode::OpTypeArray:
1143 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1145 case spirv::Opcode::OpTypeFunction:
1147 case spirv::Opcode::OpTypeImage:
1149 case spirv::Opcode::OpTypeSampledImage:
1151 case spirv::Opcode::OpTypeRuntimeArray:
1153 case spirv::Opcode::OpTypeStruct:
1155 case spirv::Opcode::OpTypeMatrix:
1157 case spirv::Opcode::OpTypeTensorARM:
1159 case spirv::Opcode::OpTypeGraphARM:
1162 return emitError(unknownLoc,
"unhandled type instruction");
1169 if (operands.size() != 3)
1170 return emitError(unknownLoc,
"OpTypePointer must have two parameters");
1172 auto pointeeType =
getType(operands[2]);
1174 return emitError(unknownLoc,
"unknown OpTypePointer pointee type <id> ")
1177 uint32_t typePointerID = operands[0];
1178 auto storageClass =
static_cast<spirv::StorageClass
>(operands[1]);
1181 for (
auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1182 deferredStructIt != std::end(deferredStructTypesInfos);) {
1183 for (
auto *unresolvedMemberIt =
1184 std::begin(deferredStructIt->unresolvedMemberTypes);
1185 unresolvedMemberIt !=
1186 std::end(deferredStructIt->unresolvedMemberTypes);) {
1187 if (unresolvedMemberIt->first == typePointerID) {
1191 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1192 typeMap[typePointerID];
1193 unresolvedMemberIt =
1194 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1196 ++unresolvedMemberIt;
1200 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1202 auto structType = deferredStructIt->deferredStructType;
1204 assert(structType &&
"expected a spirv::StructType");
1205 assert(structType.isIdentified() &&
"expected an indentified struct");
1207 if (failed(structType.trySetBody(
1208 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1209 deferredStructIt->memberDecorationsInfo,
1210 deferredStructIt->structDecorationsInfo)))
1213 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1224 if (operands.size() != 3) {
1226 "OpTypeArray must have element type and count parameters");
1231 return emitError(unknownLoc,
"OpTypeArray references undefined <id> ")
1239 return emitError(unknownLoc,
"OpTypeArray count <id> ")
1240 << operands[2] <<
"can only come from normal constant right now";
1243 if (
auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1244 count = intVal.getValue().getZExtValue();
1246 return emitError(unknownLoc,
"OpTypeArray count must come from a "
1247 "scalar integer constant instruction");
1251 elementTy, count, typeDecorations.lookup(operands[0]));
1257 assert(!operands.empty() &&
"No operands for processing function type");
1258 if (operands.size() == 1) {
1259 return emitError(unknownLoc,
"missing return type for OpTypeFunction");
1261 auto returnType =
getType(operands[1]);
1263 return emitError(unknownLoc,
"unknown return type in OpTypeFunction");
1266 for (
size_t i = 2, e = operands.size(); i < e; ++i) {
1267 auto ty =
getType(operands[i]);
1269 return emitError(unknownLoc,
"unknown argument type in OpTypeFunction");
1271 argTypes.push_back(ty);
1277 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1283 if (operands.size() != 6) {
1285 "OpTypeCooperativeMatrixKHR must have element type, "
1286 "scope, row and column parameters, and use");
1292 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1296 std::optional<spirv::Scope> scope =
1301 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1310 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Rows` references "
1311 "undefined constant <id> ")
1315 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Columns` "
1316 "references undefined constant <id> ")
1320 return emitError(unknownLoc,
"OpTypeCooperativeMatrixKHR `Use` references "
1321 "undefined constant <id> ")
1324 unsigned rows = rowsAttr.getInt();
1325 unsigned columns = columnsAttr.getInt();
1327 std::optional<spirv::CooperativeMatrixUseKHR> use =
1328 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1332 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1336 typeMap[operands[0]] =
1343 if (operands.size() != 2) {
1344 return emitError(unknownLoc,
"OpTypeRuntimeArray must have two operands");
1349 "OpTypeRuntimeArray references undefined <id> ")
1353 memberType, typeDecorations.lookup(operands[0]));
1361 if (operands.empty()) {
1362 return emitError(unknownLoc,
"OpTypeStruct must have at least result <id>");
1365 if (operands.size() == 1) {
1367 typeMap[operands[0]] =
1376 for (
auto op : llvm::drop_begin(operands, 1)) {
1378 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1380 if (!memberType && !typeForwardPtr)
1381 return emitError(unknownLoc,
"OpTypeStruct references undefined <id> ")
1385 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1387 memberTypes.push_back(memberType);
1392 if (memberDecorationMap.count(operands[0])) {
1393 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1394 for (
auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1395 if (allMemberDecorations.count(memberIndex)) {
1396 for (
auto &memberDecoration : allMemberDecorations[memberIndex]) {
1398 if (memberDecoration.first == spirv::Decoration::Offset) {
1400 if (offsetInfo.empty()) {
1401 offsetInfo.resize(memberTypes.size());
1403 offsetInfo[memberIndex] = memberDecoration.second[0];
1405 auto intType = mlir::IntegerType::get(context, 32);
1406 if (!memberDecoration.second.empty()) {
1407 memberDecorationsInfo.emplace_back(
1408 memberIndex, memberDecoration.first,
1409 IntegerAttr::get(intType, memberDecoration.second[0]));
1411 memberDecorationsInfo.emplace_back(
1412 memberIndex, memberDecoration.first, UnitAttr::get(context));
1421 if (decorations.count(operands[0])) {
1424 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1425 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(),
true));
1426 assert(decoration.has_value());
1427 structDecorationsInfo.emplace_back(decoration.value(),
1428 decorationAttr.getValue());
1432 uint32_t structID = operands[0];
1433 std::string structIdentifier = nameMap.lookup(structID).str();
1435 if (structIdentifier.empty()) {
1436 assert(unresolvedMemberTypes.empty() &&
1437 "didn't expect unresolved member types");
1439 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1442 typeMap[structID] = structTy;
1444 if (!unresolvedMemberTypes.empty())
1445 deferredStructTypesInfos.push_back(
1446 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1447 memberDecorationsInfo, structDecorationsInfo});
1448 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1449 memberDecorationsInfo,
1450 structDecorationsInfo)))
1461 if (operands.size() != 3) {
1463 return emitError(unknownLoc,
"OpTypeMatrix must have 3 operands"
1464 " (result_id, column_type, and column_count)");
1470 "OpTypeMatrix references undefined column type.")
1474 uint32_t colsCount = operands[2];
1481 unsigned size = operands.size();
1482 if (size < 2 || size > 4)
1483 return emitError(unknownLoc,
"OpTypeTensorARM must have 2-4 operands "
1484 "(result_id, element_type, (rank), (shape)) ")
1490 "OpTypeTensorARM references undefined element type ")
1500 return emitError(unknownLoc,
"OpTypeTensorARM rank must come from a "
1501 "scalar integer constant instruction");
1502 unsigned rank = rankAttr.getValue().getZExtValue();
1509 std::optional<std::pair<Attribute, Type>> shapeInfo =
1512 return emitError(unknownLoc,
"OpTypeTensorARM shape must come from a "
1513 "constant instruction of type OpTypeArray");
1515 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1517 for (
auto dimAttr : shapeArrayAttr.getValue()) {
1518 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1520 return emitError(unknownLoc,
"OpTypeTensorARM shape has an invalid "
1522 shape.push_back(dimIntAttr.getValue().getSExtValue());
1530 unsigned size = operands.size();
1532 return emitError(unknownLoc,
"OpTypeGraphARM must have at least 2 operands "
1533 "(result_id, num_inputs, (inout0_type, "
1534 "inout1_type, ...))")
1537 uint32_t numInputs = operands[1];
1540 for (
unsigned i = 2; i < size; ++i) {
1544 "OpTypeGraphARM references undefined element type.")
1547 if (i - 2 >= numInputs) {
1548 returnTypes.push_back(inOutTy);
1550 argTypes.push_back(inOutTy);
1553 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1559 if (operands.size() != 2)
1561 "OpTypeForwardPointer instruction must have two operands");
1563 typeForwardPointerIDs.insert(operands[0]);
1573 if (operands.size() != 8)
1576 "OpTypeImage with non-eight operands are not supported yet");
1580 return emitError(unknownLoc,
"OpTypeImage references undefined <id>: ")
1583 auto dim = spirv::symbolizeDim(operands[2]);
1585 return emitError(unknownLoc,
"unknown Dim for OpTypeImage: ")
1588 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1590 return emitError(unknownLoc,
"unknown Depth for OpTypeImage: ")
1593 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1595 return emitError(unknownLoc,
"unknown Arrayed for OpTypeImage: ")
1598 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1600 return emitError(unknownLoc,
"unknown MS for OpTypeImage: ") << operands[5];
1602 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1603 if (!samplerUseInfo)
1604 return emitError(unknownLoc,
"unknown Sampled for OpTypeImage: ")
1607 auto format = spirv::symbolizeImageFormat(operands[7]);
1609 return emitError(unknownLoc,
"unknown Format for OpTypeImage: ")
1613 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1614 samplingInfo.value(), samplerUseInfo.value(), format.value());
1620 if (operands.size() != 2)
1621 return emitError(unknownLoc,
"OpTypeSampledImage must have two operands");
1626 "OpTypeSampledImage references undefined <id>: ")
1639 StringRef opname = isSpec ?
"OpSpecConstant" :
"OpConstant";
1641 if (operands.size() < 2) {
1643 << opname <<
" must have type <id> and result <id>";
1645 if (operands.size() < 3) {
1647 << opname <<
" must have at least 1 more parameter";
1652 return emitError(unknownLoc,
"undefined result type from <id> ")
1656 auto checkOperandSizeForBitwidth = [&](
unsigned bitwidth) -> LogicalResult {
1657 if (bitwidth == 64) {
1658 if (operands.size() == 4) {
1662 << opname <<
" should have 2 parameters for 64-bit values";
1664 if (bitwidth <= 32) {
1665 if (operands.size() == 3) {
1671 <<
" should have 1 parameter for values with no more than 32 bits";
1673 return emitError(unknownLoc,
"unsupported OpConstant bitwidth: ")
1677 auto resultID = operands[1];
1679 if (
auto intType = dyn_cast<IntegerType>(resultType)) {
1680 auto bitwidth = intType.getWidth();
1681 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1686 if (bitwidth == 64) {
1693 } words = {operands[2], operands[3]};
1694 value = APInt(64, llvm::bit_cast<uint64_t>(words),
true);
1695 }
else if (bitwidth <= 32) {
1696 value = APInt(bitwidth, operands[2],
true,
1700 auto attr = opBuilder.getIntegerAttr(intType, value);
1707 constantMap.try_emplace(resultID, attr, intType);
1713 if (
auto floatType = dyn_cast<FloatType>(resultType)) {
1714 auto bitwidth = floatType.getWidth();
1715 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1720 if (floatType.isF64()) {
1727 } words = {operands[2], operands[3]};
1728 value = APFloat(llvm::bit_cast<double>(words));
1729 }
else if (floatType.isF32()) {
1730 value = APFloat(llvm::bit_cast<float>(operands[2]));
1731 }
else if (floatType.isF16()) {
1732 APInt data(16, operands[2]);
1733 value = APFloat(APFloat::IEEEhalf(), data);
1734 }
else if (floatType.isBF16()) {
1735 APInt data(16, operands[2]);
1736 value = APFloat(APFloat::BFloat(), data);
1739 auto attr = opBuilder.getFloatAttr(floatType, value);
1745 constantMap.try_emplace(resultID, attr, floatType);
1751 return emitError(unknownLoc,
"OpConstant can only generate values of "
1752 "scalar integer or floating-point type");
1757 if (operands.size() != 2) {
1759 << (isSpec ?
"Spec" :
"") <<
"Constant"
1760 << (isTrue ?
"True" :
"False")
1761 <<
" must have type <id> and result <id>";
1764 auto attr = opBuilder.getBoolAttr(isTrue);
1765 auto resultID = operands[1];
1771 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1779 if (operands.size() < 2) {
1781 "OpConstantComposite must have type <id> and result <id>");
1783 if (operands.size() < 3) {
1785 "OpConstantComposite must have at least 1 parameter");
1790 return emitError(unknownLoc,
"undefined result type from <id> ")
1795 elements.reserve(operands.size() - 2);
1796 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1799 return emitError(unknownLoc,
"OpConstantComposite component <id> ")
1800 << operands[i] <<
" must come from a normal constant";
1802 elements.push_back(elementInfo->first);
1805 auto resultID = operands[1];
1806 if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1809 if (
auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1810 for (
auto value : denseElemAttr.getValues<
Attribute>())
1811 flattenedElems.push_back(value);
1813 flattenedElems.push_back(element);
1817 constantMap.try_emplace(resultID, attr, tensorType);
1818 }
else if (
auto shapedType = dyn_cast<ShapedType>(resultType)) {
1822 constantMap.try_emplace(resultID, attr, shapedType);
1823 }
else if (
auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1824 auto attr = opBuilder.getArrayAttr(elements);
1825 constantMap.try_emplace(resultID, attr, resultType);
1827 return emitError(unknownLoc,
"unsupported OpConstantComposite type: ")
1836 if (operands.size() != 3) {
1839 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1845 return emitError(unknownLoc,
"undefined result type from <id> ")
1849 auto compositeType = dyn_cast<CompositeType>(resultType);
1850 if (!compositeType) {
1852 "result type from <id> is not a composite type")
1856 uint32_t resultID = operands[1];
1857 uint32_t constantID = operands[2];
1859 std::optional<std::pair<Attribute, Type>> constantInfo =
1861 if (constantInfo.has_value()) {
1862 constantCompositeReplicateMap.try_emplace(
1863 resultID, constantInfo.value().first, resultType);
1867 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1869 if (replicatedConstantCompositeInfo.has_value()) {
1870 constantCompositeReplicateMap.try_emplace(
1871 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1875 return emitError(unknownLoc,
"OpConstantCompositeReplicateEXT operand <id> ")
1877 <<
" must come from a normal constant or a "
1878 "OpConstantCompositeReplicateEXT";
1883 if (operands.size() < 2) {
1886 "OpSpecConstantComposite must have type <id> and result <id>");
1888 if (operands.size() < 3) {
1890 "OpSpecConstantComposite must have at least 1 parameter");
1895 return emitError(unknownLoc,
"undefined result type from <id> ")
1899 auto resultID = operands[1];
1903 elements.reserve(operands.size() - 2);
1904 for (
unsigned i = 2, e = operands.size(); i < e; ++i) {
1906 elements.push_back(SymbolRefAttr::get(elementInfo));
1909 auto op = spirv::SpecConstantCompositeOp::create(
1910 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1911 opBuilder.getArrayAttr(elements));
1912 specConstCompositeMap[resultID] = op;
1919 if (operands.size() != 3) {
1920 return emitError(unknownLoc,
"OpSpecConstantCompositeReplicateEXT expects "
1921 "3 operands but found ")
1927 return emitError(unknownLoc,
"undefined result type from <id> ")
1931 auto compositeType = dyn_cast<CompositeType>(resultType);
1932 if (!compositeType) {
1934 "result type from <id> is not a composite type")
1938 uint32_t resultID = operands[1];
1941 spirv::SpecConstantOp constituentSpecConstantOp =
1943 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1944 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1945 SymbolRefAttr::get(constituentSpecConstantOp));
1947 specConstCompositeReplicateMap[resultID] = op;
1954 if (operands.size() < 3)
1955 return emitError(unknownLoc,
"OpConstantOperation must have type <id>, "
1956 "result <id>, and operand opcode");
1958 uint32_t resultTypeID = operands[0];
1961 return emitError(unknownLoc,
"undefined result type from <id> ")
1964 uint32_t resultID = operands[1];
1965 spirv::Opcode enclosedOpcode =
static_cast<spirv::Opcode
>(operands[2]);
1966 auto emplaceResult = specConstOperationMap.try_emplace(
1969 enclosedOpcode, resultTypeID,
1972 if (!emplaceResult.second)
1973 return emitError(unknownLoc,
"value with <id>: ")
1974 << resultID <<
" is probably defined before.";
1980 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1996 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1997 constexpr uint32_t fakeID =
static_cast<uint32_t
>(-3);
2000 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2001 enclosedOpResultTypeAndOperands.push_back(fakeID);
2002 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2003 enclosedOpOperands.end());
2018 auto specConstOperationOp =
2019 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2021 Region &body = specConstOperationOp.getBody();
2023 body.
getBlocks().splice(body.
end(), curBlock->getParent()->getBlocks(),
2030 opBuilder.setInsertionPointToEnd(&block);
2032 spirv::YieldOp::create(opBuilder, loc, block.
front().
getResult(0));
2033 return specConstOperationOp.getResult();
2038 if (operands.size() != 2) {
2040 "OpConstantNull must only have type <id> and result <id>");
2045 return emitError(unknownLoc,
"undefined result type from <id> ")
2049 auto resultID = operands[1];
2051 if (resultType.
isIntOrFloat() || isa<VectorType>(resultType)) {
2052 attr = opBuilder.getZeroAttr(resultType);
2053 }
else if (
auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2054 if (
auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2061 constantMap.try_emplace(resultID, attr, resultType);
2065 return emitError(unknownLoc,
"unsupported OpConstantNull type: ")
2071 if (operands.size() < 3) {
2073 <<
"OpGraphConstantARM must have at least 2 operands";
2078 return emitError(unknownLoc,
"undefined result type from <id> ")
2082 uint32_t resultID = operands[1];
2084 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2085 return emitError(unknownLoc,
"result must be of type OpTypeTensorARM");
2088 APInt graph_constant_id = APInt(32, operands[2],
true);
2089 Type i32Ty = opBuilder.getIntegerType(32);
2090 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2091 graphConstantMap.try_emplace(
2103 LLVM_DEBUG(logger.startLine() <<
"[block] got exiting block for id = " <<
id
2104 <<
" @ " << block <<
"\n");
2111 auto *block = curFunction->addBlock();
2112 LLVM_DEBUG(logger.startLine() <<
"[block] created block for id = " <<
id
2113 <<
" @ " << block <<
"\n");
2114 return blockMap[id] = block;
2119 return emitError(unknownLoc,
"OpBranch must appear inside a block");
2122 if (operands.size() != 1) {
2123 return emitError(unknownLoc,
"OpBranch must take exactly one target label");
2131 spirv::BranchOp::create(opBuilder, loc,
target);
2141 "OpBranchConditional must appear inside a block");
2144 if (operands.size() != 3 && operands.size() != 5) {
2146 "OpBranchConditional must have condition, true label, "
2147 "false label, and optionally two branch weights");
2150 auto condition =
getValue(operands[0]);
2154 std::optional<std::pair<uint32_t, uint32_t>> weights;
2155 if (operands.size() == 5) {
2156 weights = std::make_pair(operands[3], operands[4]);
2162 spirv::BranchConditionalOp::create(
2163 opBuilder, loc, condition, trueBlock,
2173 return emitError(unknownLoc,
"OpLabel must appear inside a function");
2176 if (operands.size() != 1) {
2177 return emitError(unknownLoc,
"OpLabel should only have result <id>");
2180 auto labelID = operands[0];
2183 LLVM_DEBUG(logger.startLine()
2184 <<
"[block] populating block " << block <<
"\n");
2186 assert(block->empty() &&
"re-deserialize the same block!");
2188 opBuilder.setInsertionPointToStart(block);
2189 blockMap[labelID] = curBlock = block;
2196 return emitError(unknownLoc,
"a graph block must appear inside a graph");
2201 LLVM_DEBUG(logger.startLine()
2202 <<
"[block] populating block " << block <<
"\n");
2204 assert(block->
empty() &&
"re-deserialize the same block!");
2206 opBuilder.setInsertionPointToStart(block);
2207 blockMap[graphID] = curBlock = block;
2215 return emitError(unknownLoc,
"OpSelectionMerge must appear in a block");
2218 if (operands.size() < 2) {
2221 "OpSelectionMerge must specify merge target and selection control");
2226 auto selectionControl = operands[1];
2228 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2232 "a block cannot have more than one OpSelectionMerge instruction");
2241 return emitError(unknownLoc,
"OpLoopMerge must appear in a block");
2244 if (operands.size() < 3) {
2245 return emitError(unknownLoc,
"OpLoopMerge must specify merge target, "
2246 "continue target and loop control");
2252 uint32_t loopControl = operands[2];
2255 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2259 "a block cannot have more than one OpLoopMerge instruction");
2267 return emitError(unknownLoc,
"OpPhi must appear in a block");
2270 if (operands.size() < 4) {
2271 return emitError(unknownLoc,
"OpPhi must specify result type, result <id>, "
2272 "and variable-parent pairs");
2277 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2278 valueMap[operands[1]] = blockArg;
2279 LLVM_DEBUG(logger.startLine()
2280 <<
"[phi] created block argument " << blockArg
2281 <<
" id = " << operands[1] <<
" of type " << blockArgType <<
"\n");
2285 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2286 uint32_t value = operands[i];
2288 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2289 blockPhiInfo[predecessorTargetPair].push_back(value);
2290 LLVM_DEBUG(logger.startLine() <<
"[phi] predecessor @ " << predecessor
2291 <<
" with arg id = " << value <<
"\n");
2299 return emitError(unknownLoc,
"OpSwitch must appear in a block");
2301 if (operands.size() < 2)
2302 return emitError(unknownLoc,
"OpSwitch must at least specify selector and "
2303 "a default target");
2305 if (operands.size() % 2)
2307 "OpSwitch must at have an even number of operands: "
2308 "selector, default target and any number of literal and "
2309 "label <id> pairs");
2317 for (
unsigned i = 2, e = operands.size(); i < e; i += 2) {
2318 literals.push_back(operands[i]);
2323 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2332class ControlFlowStructurizer {
2335 ControlFlowStructurizer(
Location loc, uint32_t control,
2338 llvm::ScopedPrinter &logger)
2339 : location(loc), control(control), blockMergeInfo(mergeInfo),
2340 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2343 ControlFlowStructurizer(
Location loc, uint32_t control,
2346 : location(loc), control(control), blockMergeInfo(mergeInfo),
2347 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2357 LogicalResult structurize();
2362 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2365 spirv::LoopOp createLoopOp(uint32_t loopControl);
2368 void collectBlocksInConstruct();
2377 Block *continueBlock;
2383 llvm::ScopedPrinter &logger;
2389ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2392 OpBuilder builder(&mergeBlock->front());
2394 auto control =
static_cast<spirv::SelectionControl
>(selectionControl);
2395 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2396 selectionOp.addMergeBlock(builder);
2401spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2404 OpBuilder builder(&mergeBlock->front());
2406 auto control =
static_cast<spirv::LoopControl
>(loopControl);
2407 auto loopOp = spirv::LoopOp::create(builder, location, control);
2408 loopOp.addEntryAndMergeBlock(builder);
2413void ControlFlowStructurizer::collectBlocksInConstruct() {
2414 assert(constructBlocks.empty() &&
"expected empty constructBlocks");
2417 constructBlocks.insert(headerBlock);
2421 for (
unsigned i = 0; i < constructBlocks.size(); ++i) {
2422 for (
auto *successor : constructBlocks[i]->getSuccessors())
2423 if (successor != mergeBlock)
2424 constructBlocks.insert(successor);
2428LogicalResult ControlFlowStructurizer::structurize() {
2429 Operation *op =
nullptr;
2430 bool isLoop = continueBlock !=
nullptr;
2432 if (
auto loopOp = createLoopOp(control))
2433 op = loopOp.getOperation();
2435 if (
auto selectionOp = createSelectionOp(control))
2436 op = selectionOp.getOperation();
2445 mapper.
map(mergeBlock, &body.
back());
2447 collectBlocksInConstruct();
2468 OpBuilder builder(body);
2469 for (
auto *block : constructBlocks) {
2472 auto *newBlock = builder.createBlock(&body.
back());
2473 mapper.
map(block, newBlock);
2474 LLVM_DEBUG(logger.startLine() <<
"[cf] cloned block " << newBlock
2475 <<
" from block " << block <<
"\n");
2477 for (BlockArgument blockArg : block->getArguments()) {
2479 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2480 mapper.
map(blockArg, newArg);
2481 LLVM_DEBUG(logger.startLine() <<
"[cf] remapped block argument "
2482 << blockArg <<
" to " << newArg <<
"\n");
2485 LLVM_DEBUG(logger.startLine()
2486 <<
"[cf] block " << block <<
" is a function entry block\n");
2489 for (
auto &op : *block)
2490 newBlock->push_back(op.
clone(mapper));
2494 auto remapOperands = [&](Operation *op) {
2496 if (Value mappedOp = mapper.
lookupOrNull(operand.get()))
2497 operand.set(mappedOp);
2500 succOp.set(mappedOp);
2502 for (
auto &block : body)
2503 block.walk(remapOperands);
2511 headerBlock->replaceAllUsesWith(mergeBlock);
2514 logger.startLine() <<
"[cf] after cloning and fixing references:\n";
2515 headerBlock->getParentOp()->print(logger.getOStream());
2516 logger.startLine() <<
"\n";
2520 if (!mergeBlock->args_empty()) {
2521 return mergeBlock->getParentOp()->emitError(
2522 "OpPhi in loop merge block unsupported");
2528 for (BlockArgument blockArg : headerBlock->getArguments())
2529 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2533 SmallVector<Value, 4> blockArgs;
2534 if (!headerBlock->args_empty())
2535 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2539 builder.setInsertionPointToEnd(&body.front());
2540 spirv::BranchOp::create(builder, location, mapper.
lookupOrNull(headerBlock),
2541 ArrayRef<Value>(blockArgs));
2546 SmallVector<Value> valuesToYield;
2549 SmallVector<Value> outsideUses;
2563 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2568 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2569 valuesToYield.push_back(body.back().getArguments().back());
2570 outsideUses.push_back(blockArg);
2575 LLVM_DEBUG(logger.startLine() <<
"[cf] cleaning up blocks after clone\n");
2578 for (
auto *block : constructBlocks)
2579 block->dropAllReferences();
2584 for (
Block *block : constructBlocks) {
2585 for (Operation &op : *block) {
2589 outsideUses.push_back(
result);
2592 for (BlockArgument &arg : block->getArguments()) {
2593 if (!arg.use_empty()) {
2595 outsideUses.push_back(arg);
2600 assert(valuesToYield.size() == outsideUses.size());
2604 if (!valuesToYield.empty()) {
2605 LLVM_DEBUG(logger.startLine()
2606 <<
"[cf] yielding values from the selection / loop region\n");
2609 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2610 Operation *merge = llvm::getSingleElement(mergeOps);
2612 merge->setOperands(valuesToYield);
2620 builder.setInsertionPoint(&mergeBlock->front());
2622 Operation *newOp =
nullptr;
2625 newOp = spirv::LoopOp::create(builder, location,
2627 static_cast<spirv::LoopControl
>(control));
2629 newOp = spirv::SelectionOp::create(
2631 static_cast<spirv::SelectionControl
>(control));
2641 for (
unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2642 outsideUses[i].replaceAllUsesWith(op->
getResult(i));
2648 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2655 for (
auto *block : constructBlocks) {
2656 if (!block->use_empty())
2657 return emitError(block->getParent()->getLoc(),
2658 "failed control flow structurization: "
2659 "block has uses outside of the "
2660 "enclosing selection/loop construct");
2661 for (Operation &op : *block)
2663 return op.
emitOpError(
"failed control flow structurization: value has "
2664 "uses outside of the "
2665 "enclosing selection/loop construct");
2666 for (BlockArgument &arg : block->getArguments())
2667 if (!arg.use_empty())
2668 return emitError(arg.getLoc(),
"failed control flow structurization: "
2669 "block argument has uses outside of the "
2670 "enclosing selection/loop construct");
2674 for (
auto *block : constructBlocks) {
2714 auto updateMergeInfo = [&](
Block *block) -> WalkResult {
2715 auto it = blockMergeInfo.find(block);
2716 if (it != blockMergeInfo.end()) {
2718 Location loc = it->second.loc;
2722 return emitError(loc,
"failed control flow structurization: nested "
2723 "loop header block should be remapped!");
2725 Block *newContinue = it->second.continueBlock;
2729 return emitError(loc,
"failed control flow structurization: nested "
2730 "loop continue block should be remapped!");
2733 Block *newMerge = it->second.mergeBlock;
2735 newMerge = mappedTo;
2739 blockMergeInfo.
erase(it);
2740 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2747 if (block->walk(updateMergeInfo).wasInterrupted())
2755 LLVM_DEBUG(logger.startLine() <<
"[cf] changing entry block " << block
2756 <<
" to only contain a spirv.Branch op\n");
2760 builder.setInsertionPointToEnd(block);
2761 spirv::BranchOp::create(builder, location, mergeBlock);
2763 LLVM_DEBUG(logger.startLine() <<
"[cf] erasing block " << block <<
"\n");
2768 LLVM_DEBUG(logger.startLine()
2769 <<
"[cf] after structurizing construct with header block "
2770 << headerBlock <<
":\n"
2779 <<
"//----- [phi] start wiring up block arguments -----//\n";
2785 for (
const auto &info : blockPhiInfo) {
2786 Block *block = info.first.first;
2790 logger.startLine() <<
"[phi] block " << block <<
"\n";
2791 logger.startLine() <<
"[phi] before creating block argument:\n";
2793 logger.startLine() <<
"\n";
2799 opBuilder.setInsertionPoint(op);
2802 blockArgs.reserve(phiInfo.size());
2803 for (uint32_t valueId : phiInfo) {
2805 blockArgs.push_back(value);
2806 LLVM_DEBUG(logger.startLine() <<
"[phi] block argument " << value
2807 <<
" id = " << valueId <<
"\n");
2809 return emitError(unknownLoc,
"OpPhi references undefined value!");
2813 if (
auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2815 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2816 branchOp.getTarget(), blockArgs);
2818 }
else if (
auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2819 assert((branchCondOp.getTrueBlock() ==
target ||
2820 branchCondOp.getFalseBlock() ==
target) &&
2821 "expected target to be either the true or false target");
2822 if (
target == branchCondOp.getTrueTarget())
2823 spirv::BranchConditionalOp::create(
2824 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2825 blockArgs, branchCondOp.getFalseBlockArguments(),
2826 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2827 branchCondOp.getFalseTarget());
2829 spirv::BranchConditionalOp::create(
2830 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2831 branchCondOp.getTrueBlockArguments(), blockArgs,
2832 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2833 branchCondOp.getFalseBlock());
2835 branchCondOp.erase();
2836 }
else if (
auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2837 if (
target == switchOp.getDefaultTarget()) {
2841 spirv::SwitchOp::create(
2842 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2843 switchOp.getDefaultTarget(), blockArgs, literals,
2844 switchOp.getTargets(), targetOperands);
2848 auto it = llvm::find(targets,
target);
2849 assert(it != targets.end());
2850 size_t index = std::distance(targets.begin(), it);
2851 switchOp.getTargetOperandsMutable(
index).assign(blockArgs);
2854 return emitError(unknownLoc,
"unimplemented terminator for Phi creation");
2858 logger.startLine() <<
"[phi] after creating block argument:\n";
2860 logger.startLine() <<
"\n";
2863 blockPhiInfo.clear();
2868 <<
"//--- [phi] completed wiring up block arguments ---//\n";
2876 for (
auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2878 auto &[block, mergeInfo] = *it;
2881 if (mergeInfo.continueBlock)
2884 if (!block->mightHaveTerminator())
2887 Operation *terminator = block->getTerminator();
2890 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2894 bool splitHeaderMergeBlock =
false;
2895 for (
const auto &[_, mergeInfo] : blockMergeInfo) {
2896 if (mergeInfo.mergeBlock == block)
2897 splitHeaderMergeBlock =
true;
2904 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2907 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2911 blockMergeInfo.erase(block);
2912 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2920 if (!options.enableControlFlowStructurization) {
2924 <<
"//----- [cf] skip structurizing control flow -----//\n";
2932 <<
"//----- [cf] start structurizing control flow -----//\n";
2937 logger.startLine() <<
"[cf] split conditional blocks\n";
2938 logger.startLine() <<
"\n";
2945 while (!blockMergeInfo.empty()) {
2946 Block *headerBlock = blockMergeInfo.
begin()->first;
2950 logger.startLine() <<
"[cf] header block " << headerBlock <<
":\n";
2951 headerBlock->
print(logger.getOStream());
2952 logger.startLine() <<
"\n";
2956 assert(mergeBlock &&
"merge block cannot be nullptr");
2958 return emitError(unknownLoc,
"OpPhi in loop merge block unimplemented");
2960 logger.startLine() <<
"[cf] merge block " << mergeBlock <<
":\n";
2961 mergeBlock->print(logger.getOStream());
2962 logger.startLine() <<
"\n";
2966 LLVM_DEBUG(
if (continueBlock) {
2967 logger.startLine() <<
"[cf] continue block " << continueBlock <<
":\n";
2968 continueBlock->print(logger.getOStream());
2969 logger.startLine() <<
"\n";
2973 blockMergeInfo.
erase(blockMergeInfo.begin());
2974 ControlFlowStructurizer structurizer(mergeInfo.
loc, mergeInfo.
control,
2975 blockMergeInfo, headerBlock,
2976 mergeBlock, continueBlock
2982 if (failed(structurizer.structurize()))
2989 <<
"//--- [cf] completed structurizing control flow ---//\n";
3002 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3003 if (fileName.empty())
3004 fileName =
"<unknown>";
3016 if (operands.size() != 3)
3017 return emitError(unknownLoc,
"OpLine must have 3 operands");
3018 debugLine =
DebugLine{operands[0], operands[1], operands[2]};
3026 if (operands.size() < 2)
3027 return emitError(unknownLoc,
"OpString needs at least 2 operands");
3029 if (!debugInfoMap.lookup(operands[0]).empty())
3031 "duplicate debug string found for result <id> ")
3034 unsigned wordIndex = 1;
3036 if (wordIndex != operands.size())
3038 "unexpected trailing words in OpString instruction");
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
void erase()
Unlink this Block from its parent region and delete it.
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Operation * getTerminator()
Get the terminator operation of this block.
void print(raw_ostream &os)
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< BlockOperand > getBlockOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool use_empty()
Returns true if this operation has no uses.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MutableArrayRef< OpOperand > getOpOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockListType & getBlocks()
BlockListType::iterator iterator
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
static WalkResult advance()
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.