MLIR 23.0.0git
Deserializer.cpp
Go to the documentation of this file.
1//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
31#include <optional>
32
33using namespace mlir;
34
35#define DEBUG_TYPE "spirv-deserialization"
36
37//===----------------------------------------------------------------------===//
38// Utility Functions
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given `block` is a function entry block.
42static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45}
46
47//===----------------------------------------------------------------------===//
48// Deserializer Method Definitions
49//===----------------------------------------------------------------------===//
50
51spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52 MLIRContext *context,
54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56#ifndef NDEBUG
57 ,
58 logger(llvm::dbgs())
59#endif
60{
61}
62
63LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
69
70 if (failed(processHeader()))
71 return failure();
72
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode, operands)))
80 return failure();
81
82 if (failed(processInstruction(opcode, operands)))
83 return failure();
84 }
85
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
88
89 for (auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second, false))) {
91 return failure();
92 }
93 }
94
95 attachVCETriple();
96
97 LLVM_DEBUG(logger.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
99 return success();
100}
101
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
104}
105
106//===----------------------------------------------------------------------===//
107// Module structure
108//===----------------------------------------------------------------------===//
109
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
114 return cast<spirv::ModuleOp>(Operation::create(state));
115}
116
117LogicalResult spirv::Deserializer::processHeader() {
118 if (binary.size() < spirv::kHeaderWordCount)
119 return emitError(unknownLoc,
120 "SPIR-V binary module must have a 5-word header");
121
122 if (binary[0] != spirv::kMagicNumber)
123 return emitError(unknownLoc, "incorrect magic number");
124
125 // Version number bytes: 0 | major number | minor number | 0
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
131 case v: \
132 version = spirv::Version::V_1_##v; \
133 break
134
142#undef MIN_VERSION_CASE
143 default:
144 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
145 << minorVersion;
146 }
147 } else {
148 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
149 << majorVersion;
150 }
151
152 // TODO: generator number, bound, schema
153 curOffset = spirv::kHeaderWordCount;
154 return success();
155}
156
157LogicalResult
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc, "OpCapability must have one parameter");
161
162 auto cap = spirv::symbolizeCapability(operands[0]);
163 if (!cap)
164 return emitError(unknownLoc, "unknown capability: ") << operands[0];
165
166 capabilities.insert(*cap);
167 return success();
168}
169
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
171 if (words.empty()) {
172 return emitError(
173 unknownLoc,
174 "OpExtension must have a literal string for the extension name");
175 }
176
177 unsigned wordIndex = 0;
178 StringRef extName = decodeStringLiteral(words, wordIndex);
179 if (wordIndex != words.size())
180 return emitError(unknownLoc,
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
183 if (!ext)
184 return emitError(unknownLoc, "unknown extension: ") << extName;
185
186 extensions.insert(*ext);
187 return success();
188}
189
190LogicalResult
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
193 return emitError(unknownLoc,
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
196 }
197
198 unsigned wordIndex = 1;
199 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
200 if (wordIndex != words.size()) {
201 return emitError(unknownLoc,
202 "unexpected trailing words in OpExtInstImport");
203 }
204 return success();
205}
206
207void spirv::Deserializer::attachVCETriple() {
208 (*module)->setAttr(
209 spirv::ModuleOp::getVCETripleAttrName(),
210 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
211 extensions.getArrayRef(), context));
212}
213
214LogicalResult
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc, "OpMemoryModel must have two operands");
218
219 (*module)->setAttr(
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel>(operands.front())));
223
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel>(operands.back())));
227
228 return success();
229}
230
231template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
233 Location loc, OpBuilder &opBuilder,
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc, "OpDecorate with ")
238 << decorationName << " needs a cache control integer literal and a "
239 << cacheControlKind << " cache control literal";
240 }
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr = static_cast<EnumTy>(words[3]);
243 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
245 if (auto attrList =
246 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
250 return success();
251}
252
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
254 // TODO: This function should also be auto-generated. For now, since only a
255 // few decorations are processed/handled in a meaningful manner, going with a
256 // manual implementation.
257 if (words.size() < 2) {
258 return emitError(
259 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
260 }
261 auto decorationName =
262 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
265 }
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (static_cast<spirv::Decoration>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc, "OpDecorate with ")
271 << decorationName << " needs a single integer literal";
272 }
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode>(words[2])));
276 break;
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc, "OpDecorate with ")
280 << decorationName << " needs a single integer literal";
281 }
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode>(words[2])));
285 break;
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::Location:
289 case spirv::Decoration::SpecId:
290 case spirv::Decoration::Index:
291 case spirv::Decoration::Offset:
292 case spirv::Decoration::XfbBuffer:
293 case spirv::Decoration::XfbStride:
294 if (words.size() != 3) {
295 return emitError(unknownLoc, "OpDecorate with ")
296 << decorationName << " needs a single integer literal";
297 }
298 decorations[words[0]].set(
299 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
300 break;
301 case spirv::Decoration::BuiltIn:
302 if (words.size() != 3) {
303 return emitError(unknownLoc, "OpDecorate with ")
304 << decorationName << " needs a single integer literal";
305 }
306 decorations[words[0]].set(
307 symbol, opBuilder.getStringAttr(
308 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
309 break;
310 case spirv::Decoration::ArrayStride:
311 if (words.size() != 3) {
312 return emitError(unknownLoc, "OpDecorate with ")
313 << decorationName << " needs a single integer literal";
314 }
315 typeDecorations[words[0]] = words[2];
316 break;
317 case spirv::Decoration::LinkageAttributes: {
318 if (words.size() < 4) {
319 return emitError(unknownLoc, "OpDecorate with ")
320 << decorationName
321 << " needs at least 1 string and 1 integer literal";
322 }
323 // LinkageAttributes has two parameters ["linkageName", linkageType]
324 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
325 // "linkageName" is a stringliteral encoded as uint32_t,
326 // hence the size of name is variable length which results in words.size()
327 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
328 // 3 + ceildiv(strlen(name), 4).
329 unsigned wordIndex = 2;
330 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
331 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
332 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
333 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
334 StringAttr::get(context, linkageName), linkageTypeAttr);
335 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
336 break;
337 }
338 case spirv::Decoration::Aliased:
339 case spirv::Decoration::AliasedPointer:
340 case spirv::Decoration::Block:
341 case spirv::Decoration::BufferBlock:
342 case spirv::Decoration::Flat:
343 case spirv::Decoration::NonReadable:
344 case spirv::Decoration::NonWritable:
345 case spirv::Decoration::NoPerspective:
346 case spirv::Decoration::NoSignedWrap:
347 case spirv::Decoration::NoUnsignedWrap:
348 case spirv::Decoration::RelaxedPrecision:
349 case spirv::Decoration::Restrict:
350 case spirv::Decoration::RestrictPointer:
351 case spirv::Decoration::NoContraction:
352 case spirv::Decoration::Constant:
353 case spirv::Decoration::Invariant:
354 case spirv::Decoration::Patch:
355 case spirv::Decoration::Coherent:
356 if (words.size() != 2) {
357 return emitError(unknownLoc, "OpDecorate with ")
358 << decorationName << " needs a single target <id>";
359 }
360 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
361 break;
362 case spirv::Decoration::CacheControlLoadINTEL: {
363 LogicalResult res = deserializeCacheControlDecoration<
364 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
365 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
366 "load");
367 if (failed(res))
368 return res;
369 break;
370 }
371 case spirv::Decoration::CacheControlStoreINTEL: {
372 LogicalResult res = deserializeCacheControlDecoration<
373 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
374 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
375 "store");
376 if (failed(res))
377 return res;
378 break;
379 }
380 default:
381 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
382 }
383 return success();
384}
385
386LogicalResult
387spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
388 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
389 if (words.size() < 3) {
390 return emitError(unknownLoc,
391 "OpMemberDecorate must have at least 3 operands");
392 }
393
394 auto decoration = static_cast<spirv::Decoration>(words[2]);
395 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
396 return emitError(unknownLoc,
397 " missing offset specification in OpMemberDecorate with "
398 "Offset decoration");
399 }
400 ArrayRef<uint32_t> decorationOperands;
401 if (words.size() > 3) {
402 decorationOperands = words.slice(3);
403 }
404 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
405 return success();
406}
407
408LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
409 if (words.size() < 3) {
410 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
411 }
412 unsigned wordIndex = 2;
413 auto name = decodeStringLiteral(words, wordIndex);
414 if (wordIndex != words.size()) {
415 return emitError(unknownLoc,
416 "unexpected trailing words in OpMemberName instruction");
417 }
418 memberNameMap[words[0]][words[1]] = name;
419 return success();
420}
421
423 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
424 if (!decorations.contains(argID)) {
425 argAttrs[argIndex] = DictionaryAttr::get(context, {});
426 return success();
427 }
428
429 spirv::DecorationAttr foundDecorationAttr;
430 for (NamedAttribute decAttr : decorations[argID]) {
431 for (auto decoration :
432 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
433 spirv::Decoration::AliasedPointer,
434 spirv::Decoration::RestrictPointer}) {
435
436 if (decAttr.getName() !=
437 getSymbolDecoration(stringifyDecoration(decoration)))
438 continue;
439
440 if (foundDecorationAttr)
441 return emitError(unknownLoc,
442 "more than one Aliased/Restrict decorations for "
443 "function argument with result <id> ")
444 << argID;
445
446 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
447 break;
448 }
449
450 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
451 spirv::Decoration::RelaxedPrecision))) {
452 // TODO: Current implementation supports only one decoration per function
453 // parameter so RelaxedPrecision cannot be applied at the same time as,
454 // for example, Aliased/Restrict/etc. This should be relaxed to allow any
455 // combination of decoration allowed by the spec to be supported.
456 if (foundDecorationAttr)
457 return emitError(unknownLoc, "already found a decoration for function "
458 "argument with result <id> ")
459 << argID;
460
461 foundDecorationAttr = spirv::DecorationAttr::get(
462 context, spirv::Decoration::RelaxedPrecision);
463 }
464 }
465
466 if (!foundDecorationAttr)
467 return emitError(unknownLoc, "unimplemented decoration support for "
468 "function argument with result <id> ")
469 << argID;
470
471 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
472 foundDecorationAttr);
473 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
474 return success();
475}
476
477LogicalResult
479 if (curFunction) {
480 return emitError(unknownLoc, "found function inside function");
481 }
482
483 // Get the result type
484 if (operands.size() != 4) {
485 return emitError(unknownLoc, "OpFunction must have 4 parameters");
486 }
487 Type resultType = getType(operands[0]);
488 if (!resultType) {
489 return emitError(unknownLoc, "undefined result type from <id> ")
490 << operands[0];
491 }
492
493 uint32_t fnID = operands[1];
494 if (funcMap.count(fnID)) {
495 return emitError(unknownLoc, "duplicate function definition/declaration");
496 }
497
498 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
499 if (!fnControl) {
500 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
501 }
502
503 Type fnType = getType(operands[3]);
504 if (!fnType || !isa<FunctionType>(fnType)) {
505 return emitError(unknownLoc, "unknown function type from <id> ")
506 << operands[3];
507 }
508 auto functionType = cast<FunctionType>(fnType);
509
510 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
511 (functionType.getNumResults() == 1 &&
512 functionType.getResult(0) != resultType)) {
513 return emitError(unknownLoc, "mismatch in function type ")
514 << functionType << " and return type " << resultType << " specified";
515 }
516
517 std::string fnName = getFunctionSymbol(fnID);
518 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
519 functionType, fnControl.value());
520 // Processing other function attributes.
521 if (decorations.count(fnID)) {
522 for (auto attr : decorations[fnID].getAttrs()) {
523 funcOp->setAttr(attr.getName(), attr.getValue());
524 }
525 }
526 curFunction = funcMap[fnID] = funcOp;
527 auto *entryBlock = funcOp.addEntryBlock();
528 LLVM_DEBUG({
529 logger.startLine()
530 << "//===-------------------------------------------===//\n";
531 logger.startLine() << "[fn] name: " << fnName << "\n";
532 logger.startLine() << "[fn] type: " << fnType << "\n";
533 logger.startLine() << "[fn] ID: " << fnID << "\n";
534 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
535 logger.indent();
536 });
537
538 SmallVector<Attribute> argAttrs;
539 argAttrs.resize(functionType.getNumInputs());
540
541 // Parse the op argument instructions
542 if (functionType.getNumInputs()) {
543 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
544 auto argType = functionType.getInput(i);
545 spirv::Opcode opcode = spirv::Opcode::OpNop;
546 ArrayRef<uint32_t> operands;
547 if (failed(sliceInstruction(opcode, operands,
548 spirv::Opcode::OpFunctionParameter))) {
549 return failure();
550 }
551 if (opcode != spirv::Opcode::OpFunctionParameter) {
552 return emitError(
553 unknownLoc,
554 "missing OpFunctionParameter instruction for argument ")
555 << i;
556 }
557 if (operands.size() != 2) {
558 return emitError(
559 unknownLoc,
560 "expected result type and result <id> for OpFunctionParameter");
561 }
562 auto argDefinedType = getType(operands[0]);
563 if (!argDefinedType || argDefinedType != argType) {
564 return emitError(unknownLoc,
565 "mismatch in argument type between function type "
566 "definition ")
567 << functionType << " and argument type definition "
568 << argDefinedType << " at argument " << i;
569 }
570 if (getValue(operands[1])) {
571 return emitError(unknownLoc, "duplicate definition of result <id> ")
572 << operands[1];
573 }
574 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
575 return failure();
576 }
577
578 auto argValue = funcOp.getArgument(i);
579 valueMap[operands[1]] = argValue;
580 }
581 }
582
583 if (llvm::any_of(argAttrs, [](Attribute attr) {
584 auto argAttr = cast<DictionaryAttr>(attr);
585 return !argAttr.empty();
586 }))
587 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
588
589 // entryBlock is needed to access the arguments, Once that is done, we can
590 // erase the block for functions with 'Import' LinkageAttributes, since these
591 // are essentially function declarations, so they have no body.
592 auto linkageAttr = funcOp.getLinkageAttributes();
593 auto hasImportLinkage =
594 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
595 spirv::LinkageType::Import);
596 if (hasImportLinkage)
597 funcOp.eraseBody();
598
599 // RAII guard to reset the insertion point to the module's region after
600 // deserializing the body of this function.
601 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
602
603 spirv::Opcode opcode = spirv::Opcode::OpNop;
604 ArrayRef<uint32_t> instOperands;
605
606 // Special handling for the entry block. We need to make sure it starts with
607 // an OpLabel instruction. The entry block takes the same parameters as the
608 // function. All other blocks do not take any parameter. We have already
609 // created the entry block, here we need to register it to the correct label
610 // <id>.
611 if (failed(sliceInstruction(opcode, instOperands,
612 spirv::Opcode::OpFunctionEnd))) {
613 return failure();
614 }
615 if (opcode == spirv::Opcode::OpFunctionEnd) {
616 return processFunctionEnd(instOperands);
617 }
618 if (opcode != spirv::Opcode::OpLabel) {
619 return emitError(unknownLoc, "a basic block must start with OpLabel");
620 }
621 if (instOperands.size() != 1) {
622 return emitError(unknownLoc, "OpLabel should only have result <id>");
623 }
624 blockMap[instOperands[0]] = entryBlock;
625 if (failed(processLabel(instOperands))) {
626 return failure();
627 }
628
629 // Then process all the other instructions in the function until we hit
630 // OpFunctionEnd.
631 while (succeeded(sliceInstruction(opcode, instOperands,
632 spirv::Opcode::OpFunctionEnd)) &&
633 opcode != spirv::Opcode::OpFunctionEnd) {
634 if (failed(processInstruction(opcode, instOperands))) {
635 return failure();
636 }
637 }
638 if (opcode != spirv::Opcode::OpFunctionEnd) {
639 return failure();
640 }
641
642 return processFunctionEnd(instOperands);
643}
644
645LogicalResult
647 // Process OpFunctionEnd.
648 if (!operands.empty()) {
649 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
650 }
651
652 // Wire up block arguments from OpPhi instructions.
653 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
654 // ops.
655 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
656 return failure();
657 }
658
659 curBlock = nullptr;
660 curFunction = std::nullopt;
661
662 LLVM_DEBUG({
663 logger.unindent();
664 logger.startLine()
665 << "//===-------------------------------------------===//\n";
666 });
667 return success();
668}
669
670LogicalResult
672 if (operands.size() < 2) {
673 return emitError(unknownLoc,
674 "missing graph defintion in OpGraphEntryPointARM");
675 }
676
677 unsigned wordIndex = 0;
678 uint32_t graphID = operands[wordIndex++];
679 if (!graphMap.contains(graphID)) {
680 return emitError(unknownLoc,
681 "missing graph definition/declaration with id ")
682 << graphID;
683 }
684
685 spirv::GraphARMOp graphARM = graphMap[graphID];
686 StringRef name = decodeStringLiteral(operands, wordIndex);
687 graphARM.setSymName(name);
688 graphARM.setEntryPoint(true);
689
691 for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
692 if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
693 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
694 } else {
695 return emitError(unknownLoc, "undefined result <id> ")
696 << operands[wordIndex] << " while decoding OpGraphEntryPoint";
697 }
698 }
699
700 // RAII guard to reset the insertion point to previous value when done.
701 OpBuilder::InsertionGuard insertionGuard(opBuilder);
702 opBuilder.setInsertionPoint(graphARM);
703 spirv::GraphEntryPointARMOp::create(
704 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
705 opBuilder.getArrayAttr(interface));
706
707 return success();
708}
709
710LogicalResult
712 if (curGraph) {
713 return emitError(unknownLoc, "found graph inside graph");
714 }
715 // Get the result type.
716 if (operands.size() < 2) {
717 return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
718 }
719
720 Type type = getType(operands[0]);
721 if (!type || !isa<GraphType>(type)) {
722 return emitError(unknownLoc, "unknown graph type from <id> ")
723 << operands[0];
724 }
725 auto graphType = cast<GraphType>(type);
726 if (graphType.getNumResults() <= 0) {
727 return emitError(unknownLoc, "expected at least one result");
728 }
729
730 uint32_t graphID = operands[1];
731 if (graphMap.count(graphID)) {
732 return emitError(unknownLoc, "duplicate graph definition/declaration");
733 }
734
735 std::string graphName = getGraphSymbol(graphID);
736 auto graphOp =
737 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
738 curGraph = graphMap[graphID] = graphOp;
739 Block *entryBlock = graphOp.addEntryBlock();
740 LLVM_DEBUG({
741 logger.startLine()
742 << "//===-------------------------------------------===//\n";
743 logger.startLine() << "[graph] name: " << graphName << "\n";
744 logger.startLine() << "[graph] type: " << graphType << "\n";
745 logger.startLine() << "[graph] ID: " << graphID << "\n";
746 logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
747 logger.indent();
748 });
749
750 // Parse the op argument instructions.
751 for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
752 spirv::Opcode opcode;
753 ArrayRef<uint32_t> operands;
754 if (failed(sliceInstruction(opcode, operands,
755 spirv::Opcode::OpGraphInputARM))) {
756 return failure();
757 }
758 if (operands.size() != 3) {
759 return emitError(unknownLoc, "expected result type, result <id> and "
760 "input index for OpGraphInputARM");
761 }
762
763 Type argDefinedType = getType(operands[0]);
764 if (!argDefinedType) {
765 return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
766 }
767
768 if (argDefinedType != argType) {
769 return emitError(unknownLoc,
770 "mismatch in argument type between graph type "
771 "definition ")
772 << graphType << " and argument type definition " << argDefinedType
773 << " at argument " << index;
774 }
775 if (getValue(operands[1])) {
776 return emitError(unknownLoc, "duplicate definition of result <id> ")
777 << operands[1];
778 }
779
780 IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
781 if (!inputIndexAttr) {
782 return emitError(unknownLoc,
783 "unable to read inputIndex value from constant op ")
784 << operands[2];
785 }
786 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
787 valueMap[operands[1]] = argValue;
788 }
789
790 graphOutputs.resize(graphType.getNumResults());
791
792 // RAII guard to reset the insertion point to the module's region after
793 // deserializing the body of this function.
794 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
795
796 blockMap[graphID] = entryBlock;
797 if (failed(createGraphBlock(graphID))) {
798 return failure();
799 }
800
801 // Process all the instructions in the graph until and including
802 // OpGraphEndARM.
803 spirv::Opcode opcode;
804 ArrayRef<uint32_t> instOperands;
805 do {
806 if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
807 return failure();
808 }
809
810 if (failed(processInstruction(opcode, instOperands))) {
811 return failure();
812 }
813 } while (opcode != spirv::Opcode::OpGraphEndARM);
814
815 return success();
816}
817
818LogicalResult
820 if (operands.size() != 2) {
821 return emitError(
822 unknownLoc,
823 "expected value id and output index for OpGraphSetOutputARM");
824 }
825
826 uint32_t id = operands[0];
827 Value value = getValue(id);
828 if (!value) {
829 return emitError(unknownLoc, "could not find result <id> ") << id;
830 }
831
832 IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
833 if (!outputIndexAttr) {
834 return emitError(unknownLoc,
835 "unable to read outputIndex value from constant op ")
836 << operands[1];
837 }
838 graphOutputs[outputIndexAttr.getInt()] = value;
839 return success();
840}
841
842LogicalResult
844 // Create GraphOutputsARM instruction.
845 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
846
847 // Process OpGraphEndARM.
848 if (!operands.empty()) {
849 return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
850 }
851
852 curBlock = nullptr;
853 curGraph = std::nullopt;
854 graphOutputs.clear();
855
856 LLVM_DEBUG({
857 logger.unindent();
858 logger.startLine()
859 << "//===-------------------------------------------===//\n";
860 });
861 return success();
862}
863
864std::optional<std::pair<Attribute, Type>>
866 auto constIt = constantMap.find(id);
867 if (constIt == constantMap.end())
868 return std::nullopt;
869 return constIt->getSecond();
870}
871
872std::optional<std::pair<Attribute, Type>>
874 if (auto it = constantCompositeReplicateMap.find(id);
875 it != constantCompositeReplicateMap.end())
876 return it->second;
877 return std::nullopt;
878}
879
880std::optional<spirv::SpecConstOperationMaterializationInfo>
882 auto constIt = specConstOperationMap.find(id);
883 if (constIt == specConstOperationMap.end())
884 return std::nullopt;
885 return constIt->getSecond();
886}
887
889 auto funcName = nameMap.lookup(id).str();
890 if (funcName.empty()) {
891 funcName = "spirv_fn_" + std::to_string(id);
892 }
893 return funcName;
894}
895
896std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
897 std::string graphName = nameMap.lookup(id).str();
898 if (graphName.empty()) {
899 graphName = "spirv_graph_" + std::to_string(id);
900 }
901 return graphName;
902}
903
905 auto constName = nameMap.lookup(id).str();
906 if (constName.empty()) {
907 constName = "spirv_spec_const_" + std::to_string(id);
908 }
909 return constName;
910}
911
912spirv::SpecConstantOp
914 TypedAttr defaultValue) {
915 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
916 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
917 defaultValue);
918 if (decorations.count(resultID)) {
919 for (auto attr : decorations[resultID].getAttrs())
920 op->setAttr(attr.getName(), attr.getValue());
921 }
922 specConstMap[resultID] = op;
923 return op;
924}
925
926std::optional<spirv::GraphConstantARMOpMaterializationInfo>
928 auto graphConstIt = graphConstantMap.find(id);
929 if (graphConstIt == graphConstantMap.end())
930 return std::nullopt;
931 return graphConstIt->getSecond();
932}
933
934LogicalResult
936 unsigned wordIndex = 0;
937 if (operands.size() < 3) {
938 return emitError(
939 unknownLoc,
940 "OpVariable needs at least 3 operands, type, <id> and storage class");
941 }
942
943 // Result Type.
944 auto type = getType(operands[wordIndex]);
945 if (!type) {
946 return emitError(unknownLoc, "unknown result type <id> : ")
947 << operands[wordIndex];
948 }
949 auto ptrType = dyn_cast<spirv::PointerType>(type);
950 if (!ptrType) {
951 return emitError(unknownLoc,
952 "expected a result type <id> to be a spirv.ptr, found : ")
953 << type;
954 }
955 wordIndex++;
956
957 // Result <id>.
958 auto variableID = operands[wordIndex];
959 auto variableName = nameMap.lookup(variableID).str();
960 if (variableName.empty()) {
961 variableName = "spirv_var_" + std::to_string(variableID);
962 }
963 wordIndex++;
964
965 // Storage class.
966 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
967 if (ptrType.getStorageClass() != storageClass) {
968 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
969 << type << " and that specified in OpVariable instruction : "
970 << stringifyStorageClass(storageClass);
971 }
972 wordIndex++;
973
974 // Initializer.
975 FlatSymbolRefAttr initializer = nullptr;
976
977 if (wordIndex < operands.size()) {
978 Operation *op = nullptr;
979
980 if (auto initOp = getGlobalVariable(operands[wordIndex]))
981 op = initOp;
982 else if (auto initOp = getSpecConstant(operands[wordIndex]))
983 op = initOp;
984 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
985 op = initOp;
986 else
987 return emitError(unknownLoc, "unknown <id> ")
988 << operands[wordIndex] << "used as initializer";
989
990 initializer = SymbolRefAttr::get(op);
991 wordIndex++;
992 }
993 if (wordIndex != operands.size()) {
994 return emitError(unknownLoc,
995 "found more operands than expected when deserializing "
996 "OpVariable instruction, only ")
997 << wordIndex << " of " << operands.size() << " processed";
998 }
999 auto loc = createFileLineColLoc(opBuilder);
1000 auto varOp = spirv::GlobalVariableOp::create(
1001 opBuilder, loc, TypeAttr::get(type),
1002 opBuilder.getStringAttr(variableName), initializer);
1003
1004 // Decorations.
1005 if (decorations.count(variableID)) {
1006 for (auto attr : decorations[variableID].getAttrs())
1007 varOp->setAttr(attr.getName(), attr.getValue());
1008 }
1009 globalVariableMap[variableID] = varOp;
1010 return success();
1011}
1012
1013IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
1014 auto constInfo = getConstant(id);
1015 if (!constInfo) {
1016 return nullptr;
1017 }
1018 return dyn_cast<IntegerAttr>(constInfo->first);
1019}
1020
1021LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
1022 if (operands.size() < 2) {
1023 return emitError(unknownLoc, "OpName needs at least 2 operands");
1024 }
1025
1026 unsigned wordIndex = 1;
1027 StringRef name = decodeStringLiteral(operands, wordIndex);
1028 if (wordIndex != operands.size()) {
1029 return emitError(unknownLoc,
1030 "unexpected trailing words in OpName instruction");
1031 }
1032
1033 // In SPIRV it's valid for multiple OpName instructions to refer to the same
1034 // <id>. Use a "last one wins" approach to resolve such cases.
1035 nameMap.emplace_or_assign(operands[0], name);
1036
1037 return success();
1038}
1039
1040//===----------------------------------------------------------------------===//
1041// Type
1042//===----------------------------------------------------------------------===//
1043
1044LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
1045 ArrayRef<uint32_t> operands) {
1046 if (operands.empty()) {
1047 return emitError(unknownLoc, "type instruction with opcode ")
1048 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
1049 }
1050
1051 /// TODO: Types might be forward declared in some instructions and need to be
1052 /// handled appropriately.
1053 if (typeMap.count(operands[0])) {
1054 return emitError(unknownLoc, "duplicate definition for result <id> ")
1055 << operands[0];
1056 }
1057
1058 switch (opcode) {
1059 case spirv::Opcode::OpTypeVoid:
1060 if (operands.size() != 1)
1061 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
1062 typeMap[operands[0]] = opBuilder.getNoneType();
1063 break;
1064 case spirv::Opcode::OpTypeBool:
1065 if (operands.size() != 1)
1066 return emitError(unknownLoc, "OpTypeBool must have no parameters");
1067 typeMap[operands[0]] = opBuilder.getI1Type();
1068 break;
1069 case spirv::Opcode::OpTypeInt: {
1070 if (operands.size() != 3)
1071 return emitError(
1072 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
1073
1074 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1075 // to preserve or validate.
1076 // 0 indicates unsigned, or no signedness semantics
1077 // 1 indicates signed semantics."
1078 //
1079 // So we cannot differentiate signless and unsigned integers; always use
1080 // signless semantics for such cases.
1081 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1082 : IntegerType::SignednessSemantics::Signless;
1083 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1084 } break;
1085 case spirv::Opcode::OpTypeFloat: {
1086 if (operands.size() != 2 && operands.size() != 3)
1087 return emitError(unknownLoc,
1088 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1089 "or 3 operands (type, bitwidth, encoding), but got ")
1090 << operands.size();
1091 uint32_t bitWidth = operands[1];
1092
1093 Type floatTy;
1094 if (operands.size() == 2) {
1095 switch (bitWidth) {
1096 case 16:
1097 floatTy = opBuilder.getF16Type();
1098 break;
1099 case 32:
1100 floatTy = opBuilder.getF32Type();
1101 break;
1102 case 64:
1103 floatTy = opBuilder.getF64Type();
1104 break;
1105 default:
1106 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
1107 << bitWidth;
1108 }
1109 }
1110
1111 if (operands.size() == 3) {
1112 if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
1113 bitWidth == 16)
1114 floatTy = opBuilder.getBF16Type();
1115 else if (spirv::FPEncoding(operands[2]) ==
1116 spirv::FPEncoding::Float8E4M3EXT &&
1117 bitWidth == 8)
1118 floatTy = opBuilder.getF8E4M3FNType();
1119 else if (spirv::FPEncoding(operands[2]) ==
1120 spirv::FPEncoding::Float8E5M2EXT &&
1121 bitWidth == 8)
1122 floatTy = opBuilder.getF8E5M2Type();
1123 else
1124 return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
1125 << operands[2] << " and bitWidth " << bitWidth;
1126 }
1127
1128 typeMap[operands[0]] = floatTy;
1129 } break;
1130 case spirv::Opcode::OpTypeVector: {
1131 if (operands.size() != 3) {
1132 return emitError(
1133 unknownLoc,
1134 "OpTypeVector must have element type and count parameters");
1135 }
1136 Type elementTy = getType(operands[1]);
1137 if (!elementTy) {
1138 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
1139 << operands[1];
1140 }
1141 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1142 } break;
1143 case spirv::Opcode::OpTypePointer: {
1144 return processOpTypePointer(operands);
1145 } break;
1146 case spirv::Opcode::OpTypeArray:
1147 return processArrayType(operands);
1148 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1149 return processCooperativeMatrixTypeKHR(operands);
1150 case spirv::Opcode::OpTypeFunction:
1151 return processFunctionType(operands);
1152 case spirv::Opcode::OpTypeImage:
1153 return processImageType(operands);
1154 case spirv::Opcode::OpTypeSampler:
1155 return processSamplerType(operands);
1156 case spirv::Opcode::OpTypeSampledImage:
1157 return processSampledImageType(operands);
1158 case spirv::Opcode::OpTypeRuntimeArray:
1159 return processRuntimeArrayType(operands);
1160 case spirv::Opcode::OpTypeStruct:
1161 return processStructType(operands);
1162 case spirv::Opcode::OpTypeMatrix:
1163 return processMatrixType(operands);
1164 case spirv::Opcode::OpTypeTensorARM:
1165 return processTensorARMType(operands);
1166 case spirv::Opcode::OpTypeGraphARM:
1167 return processGraphTypeARM(operands);
1168 default:
1169 return emitError(unknownLoc, "unhandled type instruction");
1170 }
1171 return success();
1172}
1173
/// Deserializes OpTypePointer into a spirv::PointerType.
/// Operands are the result <id>, the storage class and the pointee type <id>.
///
/// The new pointer type may complete struct types that were deferred because
/// a member referred to a forward-declared pointer. Any such member is filled
/// in. Once a deferred struct has no unresolved members left, its body is set
/// and it is removed from `deferredStructTypesInfos`.
// NOTE(review): the declarator line for this definition is missing from this
// listing.
1174LogicalResult
1176 if (operands.size() != 3)
1177 return emitError(unknownLoc, "OpTypePointer must have two parameters");
1178
1179 auto pointeeType = getType(operands[2]);
1180 if (!pointeeType)
1181 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
1182 << operands[2];
1183
1184 uint32_t typePointerID = operands[0];
// NOTE(review): the storage class word is cast without validating that it
// names a known StorageClass enumerant.
1185 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
1186 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
1187
// Both loops erase while iterating: `erase` returns the next iterator, so each
// iterator is only advanced explicitly on the non-erasing path.
1188 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1189 deferredStructIt != std::end(deferredStructTypesInfos);) {
1190 for (auto *unresolvedMemberIt =
1191 std::begin(deferredStructIt->unresolvedMemberTypes);
1192 unresolvedMemberIt !=
1193 std::end(deferredStructIt->unresolvedMemberTypes);) {
1194 if (unresolvedMemberIt->first == typePointerID) {
1195 // The newly constructed pointer type can resolve one of the
1196 // deferred struct type members; update the memberTypes list and
1197 // clean the unresolvedMemberTypes list accordingly.
1198 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1199 typeMap[typePointerID];
1200 unresolvedMemberIt =
1201 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1202 } else {
1203 ++unresolvedMemberIt;
1204 }
1205 }
1206
1207 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1208 // All deferred struct type members are now resolved, set the struct body.
1209 auto structType = deferredStructIt->deferredStructType;
1210
1211 assert(structType && "expected a spirv::StructType");
1212 assert(structType.isIdentified() && "expected an indentified struct");
1213
1214 if (failed(structType.trySetBody(
1215 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1216 deferredStructIt->memberDecorationsInfo,
1217 deferredStructIt->structDecorationsInfo)))
1218 return failure();
1219
1220 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1221 } else {
1222 ++deferredStructIt;
1223 }
1224 }
1225
1226 return success();
1227}
1228
1229LogicalResult
1231 if (operands.size() != 3) {
1232 return emitError(unknownLoc,
1233 "OpTypeArray must have element type and count parameters");
1234 }
1235
1236 Type elementTy = getType(operands[1]);
1237 if (!elementTy) {
1238 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1239 << operands[1];
1240 }
1241
1242 unsigned count = 0;
1243 // TODO: The count can also come frome a specialization constant.
1244 auto countInfo = getConstant(operands[2]);
1245 if (!countInfo) {
1246 return emitError(unknownLoc, "OpTypeArray count <id> ")
1247 << operands[2] << "can only come from normal constant right now";
1248 }
1249
1250 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1251 count = intVal.getValue().getZExtValue();
1252 } else {
1253 return emitError(unknownLoc, "OpTypeArray count must come from a "
1254 "scalar integer constant instruction");
1255 }
1256
1257 typeMap[operands[0]] = spirv::ArrayType::get(
1258 elementTy, count, typeDecorations.lookup(operands[0]));
1259 return success();
1260}
1261
/// Deserializes OpTypeFunction into a builtin FunctionType.
/// Operands are the result <id>, the return type <id>, and zero or more
/// parameter type <id>s. A void return type produces a function type with no
/// results.
// NOTE(review): the declarator line for this definition is missing from this
// listing.
1262LogicalResult
// The caller (processType) already rejects empty operand lists, so this is an
// internal invariant rather than input validation.
1264 assert(!operands.empty() && "No operands for processing function type");
1265 if (operands.size() == 1) {
1266 return emitError(unknownLoc, "missing return type for OpTypeFunction");
1267 }
1268 auto returnType = getType(operands[1]);
1269 if (!returnType) {
1270 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1271 }
1272 SmallVector<Type, 1> argTypes;
1273 for (size_t i = 2, e = operands.size(); i < e; ++i) {
1274 auto ty = getType(operands[i]);
1275 if (!ty) {
1276 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1277 }
1278 argTypes.push_back(ty);
1279 }
// `returnTypes` views the local `returnType`; it stays valid until
// FunctionType::get below copies it.
1280 ArrayRef<Type> returnTypes;
1281 if (!isVoidType(returnType)) {
1282 returnTypes = llvm::ArrayRef(returnType);
1283 }
1284 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1285 return success();
1286}
1287
1289 ArrayRef<uint32_t> operands) {
1290 if (operands.size() != 6) {
1291 return emitError(unknownLoc,
1292 "OpTypeCooperativeMatrixKHR must have element type, "
1293 "scope, row and column parameters, and use");
1294 }
1295
1296 Type elementTy = getType(operands[1]);
1297 if (!elementTy) {
1298 return emitError(unknownLoc,
1299 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1300 << operands[1];
1301 }
1302
1303 std::optional<spirv::Scope> scope =
1304 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1305 if (!scope) {
1306 return emitError(
1307 unknownLoc,
1308 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1309 << operands[2];
1310 }
1311
1312 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1313 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1314 IntegerAttr useAttr = getConstantInt(operands[5]);
1315
1316 if (!rowsAttr)
1317 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1318 "undefined constant <id> ")
1319 << operands[3];
1320
1321 if (!columnsAttr)
1322 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1323 "references undefined constant <id> ")
1324 << operands[4];
1325
1326 if (!useAttr)
1327 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1328 "undefined constant <id> ")
1329 << operands[5];
1330
1331 unsigned rows = rowsAttr.getInt();
1332 unsigned columns = columnsAttr.getInt();
1333
1334 std::optional<spirv::CooperativeMatrixUseKHR> use =
1335 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1336 if (!use) {
1337 return emitError(
1338 unknownLoc,
1339 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1340 << operands[5];
1341 }
1342
1343 typeMap[operands[0]] =
1344 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1345 return success();
1346}
1347
/// Deserializes OpTypeRuntimeArray into a spirv::RuntimeArrayType.
/// Operands are the result <id> and the element type <id>. Whatever
/// `typeDecorations` records for the result <id> is passed to the type
/// (presumably the array stride; confirm).
// NOTE(review): the declarator line for this definition is missing from this
// listing.
1348LogicalResult
1350 if (operands.size() != 2) {
1351 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1352 }
1353 Type memberType = getType(operands[1]);
1354 if (!memberType) {
1355 return emitError(unknownLoc,
1356 "OpTypeRuntimeArray references undefined <id> ")
1357 << operands[1];
1358 }
1359 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1360 memberType, typeDecorations.lookup(operands[0]));
1361 return success();
1362}
1363
/// Deserializes OpTypeStruct into a spirv::StructType.
/// Operands are the result <id> followed by the member type <id>s.
///
/// - Member decorations from `memberDecorationMap` become offset info (for
///   Offset) or member decoration info (all others).
/// - Struct decorations from `decorations` become struct decoration info.
/// - A struct with an OpName becomes an identified struct.
/// - Members naming a forward-declared pointer (OpTypeForwardPointer) that is
///   not yet defined are recorded in `deferredStructTypesInfos`. The struct
///   body is set once they resolve.
// NOTE(review): this listing is missing the declarator line, and also the
// declarations of `offsetInfo`, `memberDecorationsInfo` and
// `structDecorationsInfo` that the body uses.
1364LogicalResult
1366 // TODO: Find a way to handle identified structs when debug info is stripped.
1367
1368 if (operands.empty()) {
1369 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1370 }
1371
1372 if (operands.size() == 1) {
1373 // Handle empty struct.
1374 typeMap[operands[0]] =
1375 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1376 return success();
1377 }
1378
1379 // First element is operand ID, second element is member index in the struct.
1380 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1381 SmallVector<Type, 4> memberTypes;
1382
1383 for (auto op : llvm::drop_begin(operands, 1)) {
1384 Type memberType = getType(op);
1385 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1386
1387 if (!memberType && !typeForwardPtr)
1388 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1389 << op;
1390
// An unresolved forward pointer is kept as a null placeholder at its member
// index, to be filled in when the OpTypePointer is processed.
1391 if (!memberType)
1392 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1393
1394 memberTypes.push_back(memberType);
1395 }
1396
1399 if (memberDecorationMap.count(operands[0])) {
1400 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1401 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1402 if (allMemberDecorations.count(memberIndex)) {
1403 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1404 // Check for offset.
1405 if (memberDecoration.first == spirv::Decoration::Offset) {
1406 // If offset info is empty, resize to the number of members;
1407 if (offsetInfo.empty()) {
1408 offsetInfo.resize(memberTypes.size());
1409 }
// NOTE(review): second[0] is read without checking that the Offset decoration
// carried a literal operand.
1410 offsetInfo[memberIndex] = memberDecoration.second[0];
1411 } else {
1412 auto intType = mlir::IntegerType::get(context, 32);
1413 if (!memberDecoration.second.empty()) {
1414 memberDecorationsInfo.emplace_back(
1415 memberIndex, memberDecoration.first,
1416 IntegerAttr::get(intType, memberDecoration.second[0]));
1417 } else {
1418 memberDecorationsInfo.emplace_back(
1419 memberIndex, memberDecoration.first, UnitAttr::get(context));
1420 }
1421 }
1422 }
1423 }
1424 }
1425 }
1426
// Struct-level decorations are stored under snake_case attribute names. Convert
// them back to CamelCase to recover the Decoration enumerant.
1428 if (decorations.count(operands[0])) {
1429 NamedAttrList &allDecorations = decorations[operands[0]];
1430 for (NamedAttribute &decorationAttr : allDecorations) {
1431 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1432 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
1433 assert(decoration.has_value());
1434 structDecorationsInfo.emplace_back(decoration.value(),
1435 decorationAttr.getValue());
1436 }
1437 }
1438
1439 uint32_t structID = operands[0];
1440 std::string structIdentifier = nameMap.lookup(structID).str();
1441
// NOTE(review): without an OpName the struct is literal and cannot be
// completed later. A forward-pointer member then only trips this assert
// (debug builds); see the TODO at the top of this function.
1442 if (structIdentifier.empty()) {
1443 assert(unresolvedMemberTypes.empty() &&
1444 "didn't expect unresolved member types");
1445 typeMap[structID] = spirv::StructType::get(
1446 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1447 } else {
1448 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1449 typeMap[structID] = structTy;
1450
1451 if (!unresolvedMemberTypes.empty())
1452 deferredStructTypesInfos.push_back(
1453 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1454 memberDecorationsInfo, structDecorationsInfo});
1455 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1456 memberDecorationsInfo,
1457 structDecorationsInfo)))
1458 return failure();
1459 }
1460
1461 // TODO: Update StructType to have member name as attribute as
1462 // well.
1463 return success();
1464}
1465
1466LogicalResult
1468 if (operands.size() != 3) {
1469 // Three operands are needed: result_id, column_type, and column_count
1470 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1471 " (result_id, column_type, and column_count)");
1472 }
1473 // Matrix columns must be of vector type
1474 Type elementTy = getType(operands[1]);
1475 if (!elementTy) {
1476 return emitError(unknownLoc,
1477 "OpTypeMatrix references undefined column type.")
1478 << operands[1];
1479 }
1480
1481 uint32_t colsCount = operands[2];
1482 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1483 return success();
1484}
1485
/// Deserializes OpTypeTensorARM into a TensorArmType. Operands, in order:
/// - result <id>
/// - element type <id>
/// - optional rank <id> (an integer constant)
/// - optional shape <id> (an array constant)
///
/// Without a rank the tensor is unranked. With a rank but no shape, every
/// dimension is dynamic.
// NOTE(review): this listing is missing the declarator line and the
// declaration of `shape` used in the final section.
1486LogicalResult
1488 unsigned size = operands.size();
1489 if (size < 2 || size > 4)
1490 return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
1491 "(result_id, element_type, (rank), (shape)) ")
1492 << size;
1493
1494 Type elementTy = getType(operands[1]);
1495 if (!elementTy)
1496 return emitError(unknownLoc,
1497 "OpTypeTensorARM references undefined element type ")
1498 << operands[1];
1499
1500 if (size == 2) {
1501 typeMap[operands[0]] = TensorArmType::get({}, elementTy);
1502 return success();
1503 }
1504
1505 IntegerAttr rankAttr = getConstantInt(operands[2]);
1506 if (!rankAttr)
1507 return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
1508 "scalar integer constant instruction");
1509 unsigned rank = rankAttr.getValue().getZExtValue();
1510 if (size == 3) {
1511 SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1512 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1513 return success();
1514 }
1515
1516 std::optional<std::pair<Attribute, Type>> shapeInfo =
1517 getConstant(operands[3]);
1518 if (!shapeInfo)
1519 return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
1520 "constant instruction of type OpTypeArray");
1521
// NOTE(review): dyn_cast yields null when the shape constant is not an
// ArrayAttr (e.g. a scalar constant), and shapeArrayAttr.getValue() is then
// called without a check. The number of dimensions is also not compared
// against `rank`. Consider diagnosing both.
1522 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1524 for (auto dimAttr : shapeArrayAttr.getValue()) {
1525 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1526 if (!dimIntAttr)
1527 return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
1528 "dimension size");
1529 shape.push_back(dimIntAttr.getValue().getSExtValue());
1530 }
1531 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1532 return success();
1533}
1534
1535LogicalResult
1537 unsigned size = operands.size();
1538 if (size < 2) {
1539 return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
1540 "(result_id, num_inputs, (inout0_type, "
1541 "inout1_type, ...))")
1542 << size;
1543 }
1544 uint32_t numInputs = operands[1];
1545 SmallVector<Type, 1> argTypes;
1546 SmallVector<Type, 1> returnTypes;
1547 for (unsigned i = 2; i < size; ++i) {
1548 Type inOutTy = getType(operands[i]);
1549 if (!inOutTy) {
1550 return emitError(unknownLoc,
1551 "OpTypeGraphARM references undefined element type.")
1552 << operands[i];
1553 }
1554 if (i - 2 >= numInputs) {
1555 returnTypes.push_back(inOutTy);
1556 } else {
1557 argTypes.push_back(inOutTy);
1558 }
1559 }
1560 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1561 return success();
1562}
1563
/// Handles OpTypeForwardPointer (pointer type <id>, storage class).
/// It only remembers the <id> in `typeForwardPointerIDs`, so struct members
/// may refer to it before the defining OpTypePointer is seen.
// NOTE(review): the declarator line for this definition is missing from this
// listing.
1564LogicalResult
1566 if (operands.size() != 2)
1567 return emitError(unknownLoc,
1568 "OpTypeForwardPointer instruction must have two operands");
1569
1570 typeForwardPointerIDs.insert(operands[0]);
1571 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1572 // instruction that defines the actual type.
1573
1574 return success();
1575}
1576
/// Deserializes OpTypeImage into a spirv::ImageType. Only the 8-operand form
/// (no access qualifier) is accepted. The operands are:
/// - result <id>
/// - sampled type <id>
/// - Dim
/// - Depth
/// - Arrayed
/// - MS
/// - Sampled
/// - Image Format
///
/// Each enum operand is validated through its symbolize* helper.
// NOTE(review): the declarator line for this definition is missing from this
// listing.
1577LogicalResult
1579 // TODO: Add support for Access Qualifier.
1580 if (operands.size() != 8)
1581 return emitError(
1582 unknownLoc,
1583 "OpTypeImage with non-eight operands are not supported yet");
1584
1585 Type elementTy = getType(operands[1]);
1586 if (!elementTy)
1587 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1588 << operands[1];
1589
1590 auto dim = spirv::symbolizeDim(operands[2]);
1591 if (!dim)
1592 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1593 << operands[2];
1594
1595 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1596 if (!depthInfo)
1597 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1598 << operands[3];
1599
1600 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1601 if (!arrayedInfo)
1602 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1603 << operands[4];
1604
1605 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1606 if (!samplingInfo)
1607 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1608
1609 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1610 if (!samplerUseInfo)
1611 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1612 << operands[6];
1613
1614 auto format = spirv::symbolizeImageFormat(operands[7]);
1615 if (!format)
1616 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1617 << operands[7];
1618
1619 typeMap[operands[0]] = spirv::ImageType::get(
1620 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1621 samplingInfo.value(), samplerUseInfo.value(), format.value());
1622 return success();
1623}
1624
/// Deserializes OpTypeSampledImage into a spirv::SampledImageType.
/// Operands are the result <id> and the image type <id>.
// NOTE(review): the declarator line for this definition is missing from this
// listing.
1625LogicalResult
1627 if (operands.size() != 2)
1628 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1629
1630 Type elementTy = getType(operands[1]);
1631 if (!elementTy)
1632 return emitError(unknownLoc,
1633 "OpTypeSampledImage references undefined <id>: ")
1634 << operands[1];
1635
1636 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1637 return success();
1638}
1639
/// Deserializes OpTypeSampler, which carries only its result <id>, into a
/// spirv::SamplerType.
// NOTE(review): the declarator line for this definition is missing from this
// listing.
1640LogicalResult
1642 if (operands.size() != 1)
1643 return emitError(unknownLoc, "OpTypeSampler must have no parameters");
1644
1645 typeMap[operands[0]] = spirv::SamplerType::get(context);
1646 return success();
1647}
1648
1649//===----------------------------------------------------------------------===//
1650// Constant
1651//===----------------------------------------------------------------------===//
1652
1654 bool isSpec) {
1655 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1656
1657 if (operands.size() < 2) {
1658 return emitError(unknownLoc)
1659 << opname << " must have type <id> and result <id>";
1660 }
1661 if (operands.size() < 3) {
1662 return emitError(unknownLoc)
1663 << opname << " must have at least 1 more parameter";
1664 }
1665
1666 Type resultType = getType(operands[0]);
1667 if (!resultType) {
1668 return emitError(unknownLoc, "undefined result type from <id> ")
1669 << operands[0];
1670 }
1671
1672 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1673 if (bitwidth == 64) {
1674 if (operands.size() == 4) {
1675 return success();
1676 }
1677 return emitError(unknownLoc)
1678 << opname << " should have 2 parameters for 64-bit values";
1679 }
1680 if (bitwidth <= 32) {
1681 if (operands.size() == 3) {
1682 return success();
1683 }
1684
1685 return emitError(unknownLoc)
1686 << opname
1687 << " should have 1 parameter for values with no more than 32 bits";
1688 }
1689 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1690 << bitwidth;
1691 };
1692
1693 auto resultID = operands[1];
1694
1695 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1696 auto bitwidth = intType.getWidth();
1697 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1698 return failure();
1699 }
1700
1701 APInt value;
1702 if (bitwidth == 64) {
1703 // 64-bit integers are represented with two SPIR-V words. According to
1704 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1705 // literal’s low-order words appear first."
1706 struct DoubleWord {
1707 uint32_t word1;
1708 uint32_t word2;
1709 } words = {operands[2], operands[3]};
1710 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1711 } else if (bitwidth <= 32) {
1712 value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1713 /*implicitTrunc=*/true);
1714 }
1715
1716 auto attr = opBuilder.getIntegerAttr(intType, value);
1717
1718 if (isSpec) {
1719 createSpecConstant(unknownLoc, resultID, attr);
1720 } else {
1721 // For normal constants, we just record the attribute (and its type) for
1722 // later materialization at use sites.
1723 constantMap.try_emplace(resultID, attr, intType);
1724 }
1725
1726 return success();
1727 }
1728
1729 if (auto floatType = dyn_cast<FloatType>(resultType)) {
1730 auto bitwidth = floatType.getWidth();
1731 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1732 return failure();
1733 }
1734
1735 APFloat value(0.f);
1736 if (floatType.isF64()) {
1737 // Double values are represented with two SPIR-V words. According to
1738 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1739 // literal’s low-order words appear first."
1740 struct DoubleWord {
1741 uint32_t word1;
1742 uint32_t word2;
1743 } words = {operands[2], operands[3]};
1744 value = APFloat(llvm::bit_cast<double>(words));
1745 } else if (floatType.isF32()) {
1746 value = APFloat(llvm::bit_cast<float>(operands[2]));
1747 } else if (floatType.isF16()) {
1748 APInt data(16, operands[2]);
1749 value = APFloat(APFloat::IEEEhalf(), data);
1750 } else if (floatType.isBF16()) {
1751 APInt data(16, operands[2]);
1752 value = APFloat(APFloat::BFloat(), data);
1753 } else if (floatType.isF8E4M3FN()) {
1754 APInt data(8, operands[2]);
1755 value = APFloat(APFloat::Float8E4M3FN(), data);
1756 } else if (floatType.isF8E5M2()) {
1757 APInt data(8, operands[2]);
1758 value = APFloat(APFloat::Float8E5M2(), data);
1759 }
1760
1761 auto attr = opBuilder.getFloatAttr(floatType, value);
1762 if (isSpec) {
1763 createSpecConstant(unknownLoc, resultID, attr);
1764 } else {
1765 // For normal constants, we just record the attribute (and its type) for
1766 // later materialization at use sites.
1767 constantMap.try_emplace(resultID, attr, floatType);
1768 }
1769
1770 return success();
1771 }
1772
1773 return emitError(unknownLoc, "OpConstant can only generate values of "
1774 "scalar integer or floating-point type");
1775}
1776
/// Deserializes Op[Spec]ConstantTrue / Op[Spec]ConstantFalse, as selected by
/// `isTrue` and `isSpec`. Operands are the result type <id> and the result
/// <id>.
///
/// A spec constant becomes a spirv.SpecConstant op. A normal constant is
/// recorded in `constantMap` as an i1 BoolAttr.
// NOTE(review): the first line of this definition's declarator is missing
// from this listing. The result type <id> (operands[0]) is not checked to be
// a boolean type.
1778 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1779 if (operands.size() != 2) {
1780 return emitError(unknownLoc, "Op")
1781 << (isSpec ? "Spec" : "") << "Constant"
1782 << (isTrue ? "True" : "False")
1783 << " must have type <id> and result <id>";
1784 }
1785
1786 auto attr = opBuilder.getBoolAttr(isTrue);
1787 auto resultID = operands[1];
1788 if (isSpec) {
1789 createSpecConstant(unknownLoc, resultID, attr);
1790 } else {
1791 // For normal constants, we just record the attribute (and its type) for
1792 // later materialization at use sites.
1793 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1794 }
1795
1796 return success();
1797}
1798
/// Deserializes OpConstantComposite.
/// Operands are the result type <id>, the result <id>, and the constituent
/// <id>s, each of which must be a normal constant. The composite is only
/// recorded in `constantMap`:
/// - a TensorArmType becomes a DenseElementsAttr, with nested dense
///   constituents flattened;
/// - any other ShapedType becomes a DenseElementsAttr;
/// - a spirv::ArrayType becomes an ArrayAttr.
// NOTE(review): this listing is missing the declarator line and the
// declaration of `elements`. The number of constituents is not checked
// against the shaped type's element count before DenseElementsAttr::get.
1799LogicalResult
1801 if (operands.size() < 2) {
1802 return emitError(unknownLoc,
1803 "OpConstantComposite must have type <id> and result <id>");
1804 }
1805 if (operands.size() < 3) {
1806 return emitError(unknownLoc,
1807 "OpConstantComposite must have at least 1 parameter");
1808 }
1809
1810 Type resultType = getType(operands[0]);
1811 if (!resultType) {
1812 return emitError(unknownLoc, "undefined result type from <id> ")
1813 << operands[0];
1814 }
1815
1817 elements.reserve(operands.size() - 2);
1818 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1819 auto elementInfo = getConstant(operands[i]);
1820 if (!elementInfo) {
1821 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1822 << operands[i] << " must come from a normal constant";
1823 }
1824 elements.push_back(elementInfo->first);
1825 }
1826
1827 auto resultID = operands[1];
// Tensor constituents may themselves be dense attributes (sub-tensors);
// flatten them so the result is one dense attribute over all scalars.
1828 if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1829 SmallVector<Attribute> flattenedElems;
1830 for (Attribute element : elements) {
1831 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1832 for (auto value : denseElemAttr.getValues<Attribute>())
1833 flattenedElems.push_back(value);
1834 } else {
1835 flattenedElems.push_back(element);
1836 }
1837 }
1838 auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
1839 constantMap.try_emplace(resultID, attr, tensorType);
1840 } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1841 auto attr = DenseElementsAttr::get(shapedType, elements);
1842 // For normal constants, we just record the attribute (and its type) for
1843 // later materialization at use sites.
1844 constantMap.try_emplace(resultID, attr, shapedType);
1845 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1846 auto attr = opBuilder.getArrayAttr(elements);
1847 constantMap.try_emplace(resultID, attr, resultType);
1848 } else {
1849 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1850 << resultType;
1851 }
1852
1853 return success();
1854}
1855
/// Deserializes OpConstantCompositeReplicateEXT.
/// Operands are the composite result type <id>, the result <id>, and the <id>
/// of the replicated value.
///
/// The value must be a normal constant or another replicated composite. Its
/// attribute is recorded with the result type in
/// `constantCompositeReplicateMap`.
// NOTE(review): this listing is missing the first declarator line and the
// initializer expression of `replicatedConstantCompositeInfo`. The "is not a
// composite type" message streams the <id> with no separating space.
1857 ArrayRef<uint32_t> operands) {
1858 if (operands.size() != 3) {
1859 return emitError(
1860 unknownLoc,
1861 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1862 << operands.size();
1863 }
1864
1865 Type resultType = getType(operands[0]);
1866 if (!resultType) {
1867 return emitError(unknownLoc, "undefined result type from <id> ")
1868 << operands[0];
1869 }
1870
1871 auto compositeType = dyn_cast<CompositeType>(resultType);
1872 if (!compositeType) {
1873 return emitError(unknownLoc,
1874 "result type from <id> is not a composite type")
1875 << operands[0];
1876 }
1877
1878 uint32_t resultID = operands[1];
1879 uint32_t constantID = operands[2];
1880
// First try a normal constant as the replicated value.
1881 std::optional<std::pair<Attribute, Type>> constantInfo =
1882 getConstant(constantID);
1883 if (constantInfo.has_value()) {
1884 constantCompositeReplicateMap.try_emplace(
1885 resultID, constantInfo.value().first, resultType);
1886 return success();
1887 }
1888
// Otherwise accept a previously replicated composite (nested replication).
1889 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1891 if (replicatedConstantCompositeInfo.has_value()) {
1892 constantCompositeReplicateMap.try_emplace(
1893 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1894 return success();
1895 }
1896
1897 return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
1898 << constantID
1899 << " must come from a normal constant or a "
1900 "OpConstantCompositeReplicateEXT";
1901}
1902
/// Deserializes OpSpecConstantComposite into a spirv.SpecConstantComposite op.
/// Operands are the result type <id>, the result <id>, and the constituent
/// <id>s. The op references its constituents by symbol and is registered in
/// `specConstCompositeMap`.
// NOTE(review): this listing is missing the declarator line and the
// declaration of `elements`. getSpecConstant's result is passed to
// SymbolRefAttr::get without a null check, so a constituent that is not a
// spec constant is not diagnosed here.
1903LogicalResult
1905 if (operands.size() < 2) {
1906 return emitError(
1907 unknownLoc,
1908 "OpSpecConstantComposite must have type <id> and result <id>");
1909 }
1910 if (operands.size() < 3) {
1911 return emitError(unknownLoc,
1912 "OpSpecConstantComposite must have at least 1 parameter");
1913 }
1914
1915 Type resultType = getType(operands[0]);
1916 if (!resultType) {
1917 return emitError(unknownLoc, "undefined result type from <id> ")
1918 << operands[0];
1919 }
1920
1921 auto resultID = operands[1];
1922 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1923
1925 elements.reserve(operands.size() - 2);
1926 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1927 auto elementInfo = getSpecConstant(operands[i]);
1928 elements.push_back(SymbolRefAttr::get(elementInfo));
1929 }
1930
1931 auto op = spirv::SpecConstantCompositeOp::create(
1932 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1933 opBuilder.getArrayAttr(elements));
1934 specConstCompositeMap[resultID] = op;
1935
1936 return success();
1937}
1938
1940 ArrayRef<uint32_t> operands) {
1941 if (operands.size() != 3) {
1942 return emitError(unknownLoc, "OpSpecConstantCompositeReplicateEXT expects "
1943 "3 operands but found ")
1944 << operands.size();
1945 }
1946
1947 Type resultType = getType(operands[0]);
1948 if (!resultType) {
1949 return emitError(unknownLoc, "undefined result type from <id> ")
1950 << operands[0];
1951 }
1952
1953 auto compositeType = dyn_cast<CompositeType>(resultType);
1954 if (!compositeType) {
1955 return emitError(unknownLoc,
1956 "result type from <id> is not a composite type")
1957 << operands[0];
1958 }
1959
1960 uint32_t resultID = operands[1];
1961
1962 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1963 spirv::SpecConstantOp constituentSpecConstantOp =
1964 getSpecConstant(operands[2]);
1965 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1966 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1967 SymbolRefAttr::get(constituentSpecConstantOp));
1968
1969 specConstCompositeReplicateMap[resultID] = op;
1970
1971 return success();
1972}
1973
1974LogicalResult
1976 if (operands.size() < 3)
1977 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1978 "result <id>, and operand opcode");
1979
1980 uint32_t resultTypeID = operands[0];
1981
1982 if (!getType(resultTypeID))
1983 return emitError(unknownLoc, "undefined result type from <id> ")
1984 << resultTypeID;
1985
1986 uint32_t resultID = operands[1];
1987 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1988 auto emplaceResult = specConstOperationMap.try_emplace(
1989 resultID,
1991 enclosedOpcode, resultTypeID,
1992 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1993
1994 if (!emplaceResult.second)
1995 return emitError(unknownLoc, "value with <id>: ")
1996 << resultID << " is probably defined before.";
1997
1998 return success();
1999}
2000
2002 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
2003 ArrayRef<uint32_t> enclosedOpOperands) {
2004
2005 Type resultType = getType(resultTypeID);
2006
2007 // Instructions wrapped by OpSpecConstantOp need an ID for their
2008 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
2009 // dialect wrapped op. For that purpose, a new value map is created and "fake"
2010 // ID in that map is assigned to the result of the enclosed instruction. Note
2011 // that there is no need to update this fake ID since we only need to
2012 // reference the created Value for the enclosed op from the spv::YieldOp
2013 // created later in this method (both of which are the only values in their
2014 // region: the SpecConstantOperation's region). If we encounter another
2015 // SpecConstantOperation in the module, we simply re-use the fake ID since the
2016 // previous Value assigned to it isn't visible in the current scope anyway.
2017 DenseMap<uint32_t, Value> newValueMap;
2018 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
2019 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
2020
2021 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
2022 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2023 enclosedOpResultTypeAndOperands.push_back(fakeID);
2024 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2025 enclosedOpOperands.end());
2026
2027 // Process enclosed instruction before creating the enclosing
2028 // specConstantOperation (and its region). This way, references to constants,
2029 // global variables, and spec constants will be materialized outside the new
2030 // op's region. For more info, see Deserializer::getValue's implementation.
2031 if (failed(
2032 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
2033 return Value();
2034
2035 // Since the enclosed op is emitted in the current block, split it in a
2036 // separate new block.
2037 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
2038
2039 auto loc = createFileLineColLoc(opBuilder);
2040 auto specConstOperationOp =
2041 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2042
2043 Region &body = specConstOperationOp.getBody();
2044 // Move the new block into SpecConstantOperation's body.
2045 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
2046 Region::iterator(enclosedBlock));
2047 Block &block = body.back();
2048
2049 // RAII guard to reset the insertion point to the module's region after
2050 // deserializing the body of the specConstantOperation.
2051 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
2052 opBuilder.setInsertionPointToEnd(&block);
2053
2054 spirv::YieldOp::create(opBuilder, loc, block.front().getResult(0));
2055 return specConstOperationOp.getResult();
2056}
2057
2058LogicalResult
2060 if (operands.size() != 2) {
2061 return emitError(unknownLoc,
2062 "OpConstantNull must only have type <id> and result <id>");
2063 }
2064
2065 Type resultType = getType(operands[0]);
2066 if (!resultType) {
2067 return emitError(unknownLoc, "undefined result type from <id> ")
2068 << operands[0];
2069 }
2070
2071 auto resultID = operands[1];
2072 Attribute attr;
2073 if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
2074 attr = opBuilder.getZeroAttr(resultType);
2075 } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2076 if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2077 attr = DenseElementsAttr::get(tensorType, element);
2078 }
2079
2080 if (attr) {
2081 // For normal constants, we just record the attribute (and its type) for
2082 // later materialization at use sites.
2083 constantMap.try_emplace(resultID, attr, resultType);
2084 return success();
2085 }
2086
2087 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
2088 << resultType;
2089}
2090
2091LogicalResult
2093 if (operands.size() < 3) {
2094 return emitError(unknownLoc)
2095 << "OpGraphConstantARM must have at least 2 operands";
2096 }
2097
2098 Type resultType = getType(operands[0]);
2099 if (!resultType) {
2100 return emitError(unknownLoc, "undefined result type from <id> ")
2101 << operands[0];
2102 }
2103
2104 uint32_t resultID = operands[1];
2105
2106 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2107 return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
2108 }
2109
2110 APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
2111 Type i32Ty = opBuilder.getIntegerType(32);
2112 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2113 graphConstantMap.try_emplace(
2114 resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2115
2116 return success();
2117}
2118
2119//===----------------------------------------------------------------------===//
2120// Control flow
2121//===----------------------------------------------------------------------===//
2122
2124 if (auto *block = getBlock(id)) {
2125 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
2126 << " @ " << block << "\n");
2127 return block;
2128 }
2129
2130 // We don't know where this block will be placed finally (in a
2131 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
2132 // function for now and sort out the proper place later.
2133 auto *block = curFunction->addBlock();
2134 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
2135 << " @ " << block << "\n");
2136 return blockMap[id] = block;
2137}
2138
2140 if (!curBlock) {
2141 return emitError(unknownLoc, "OpBranch must appear inside a block");
2142 }
2143
2144 if (operands.size() != 1) {
2145 return emitError(unknownLoc, "OpBranch must take exactly one target label");
2146 }
2147
2148 auto *target = getOrCreateBlock(operands[0]);
2149 auto loc = createFileLineColLoc(opBuilder);
2150 // The preceding instruction for the OpBranch instruction could be an
2151 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
2152 // the same OpLine information.
2153 spirv::BranchOp::create(opBuilder, loc, target);
2154
2156 return success();
2157}
2158
2159LogicalResult
2161 if (!curBlock) {
2162 return emitError(unknownLoc,
2163 "OpBranchConditional must appear inside a block");
2164 }
2165
2166 if (operands.size() != 3 && operands.size() != 5) {
2167 return emitError(unknownLoc,
2168 "OpBranchConditional must have condition, true label, "
2169 "false label, and optionally two branch weights");
2170 }
2171
2172 auto condition = getValue(operands[0]);
2173 auto *trueBlock = getOrCreateBlock(operands[1]);
2174 auto *falseBlock = getOrCreateBlock(operands[2]);
2175
2176 std::optional<std::pair<uint32_t, uint32_t>> weights;
2177 if (operands.size() == 5) {
2178 weights = std::make_pair(operands[3], operands[4]);
2179 }
2180 // The preceding instruction for the OpBranchConditional instruction could be
2181 // an OpSelectionMerge instruction, in this case they will have the same
2182 // OpLine information.
2183 auto loc = createFileLineColLoc(opBuilder);
2184 spirv::BranchConditionalOp::create(
2185 opBuilder, loc, condition, trueBlock,
2186 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
2187 /*falseArguments=*/ArrayRef<Value>(), weights);
2188
2190 return success();
2191}
2192
2194 if (!curFunction) {
2195 return emitError(unknownLoc, "OpLabel must appear inside a function");
2196 }
2197
2198 if (operands.size() != 1) {
2199 return emitError(unknownLoc, "OpLabel should only have result <id>");
2200 }
2201
2202 auto labelID = operands[0];
2203 // We may have forward declared this block.
2204 auto *block = getOrCreateBlock(labelID);
2205 LLVM_DEBUG(logger.startLine()
2206 << "[block] populating block " << block << "\n");
2207 // If we have seen this block, make sure it was just a forward declaration.
2208 assert(block->empty() && "re-deserialize the same block!");
2209
2210 opBuilder.setInsertionPointToStart(block);
2211 blockMap[labelID] = curBlock = block;
2212
2213 return success();
2214}
2215
2216LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
2217 if (!curGraph) {
2218 return emitError(unknownLoc, "a graph block must appear inside a graph");
2219 }
2220
2221 // We may have forward declared this block.
2222 Block *block = getOrCreateBlock(graphID);
2223 LLVM_DEBUG(logger.startLine()
2224 << "[block] populating block " << block << "\n");
2225 // If we have seen this block, make sure it was just a forward declaration.
2226 assert(block->empty() && "re-deserialize the same block!");
2227
2228 opBuilder.setInsertionPointToStart(block);
2229 blockMap[graphID] = curBlock = block;
2230
2231 return success();
2232}
2233
2234LogicalResult
2236 if (!curBlock) {
2237 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
2238 }
2239
2240 if (operands.size() < 2) {
2241 return emitError(
2242 unknownLoc,
2243 "OpSelectionMerge must specify merge target and selection control");
2244 }
2245
2246 auto *mergeBlock = getOrCreateBlock(operands[0]);
2247 auto loc = createFileLineColLoc(opBuilder);
2248 auto selectionControl = operands[1];
2249
2250 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2251 .second) {
2252 return emitError(
2253 unknownLoc,
2254 "a block cannot have more than one OpSelectionMerge instruction");
2255 }
2256
2257 return success();
2258}
2259
2260LogicalResult
2262 if (!curBlock) {
2263 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
2264 }
2265
2266 if (operands.size() < 3) {
2267 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
2268 "continue target and loop control");
2269 }
2270
2271 auto *mergeBlock = getOrCreateBlock(operands[0]);
2272 auto *continueBlock = getOrCreateBlock(operands[1]);
2273 auto loc = createFileLineColLoc(opBuilder);
2274 uint32_t loopControl = operands[2];
2275
2276 if (!blockMergeInfo
2277 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2278 .second) {
2279 return emitError(
2280 unknownLoc,
2281 "a block cannot have more than one OpLoopMerge instruction");
2282 }
2283
2284 return success();
2285}
2286
2288 if (!curBlock) {
2289 return emitError(unknownLoc, "OpPhi must appear in a block");
2290 }
2291
2292 if (operands.size() < 4) {
2293 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
2294 "and variable-parent pairs");
2295 }
2296
2297 // Create a block argument for this OpPhi instruction.
2298 Type blockArgType = getType(operands[0]);
2299 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2300 valueMap[operands[1]] = blockArg;
2301 LLVM_DEBUG(logger.startLine()
2302 << "[phi] created block argument " << blockArg
2303 << " id = " << operands[1] << " of type " << blockArgType << "\n");
2304
2305 // For each (value, predecessor) pair, insert the value to the predecessor's
2306 // blockPhiInfo entry so later we can fix the block argument there.
2307 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2308 uint32_t value = operands[i];
2309 Block *predecessor = getOrCreateBlock(operands[i + 1]);
2310 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2311 blockPhiInfo[predecessorTargetPair].push_back(value);
2312 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
2313 << " with arg id = " << value << "\n");
2314 }
2315
2316 return success();
2317}
2318
2320 if (!curBlock)
2321 return emitError(unknownLoc, "OpSwitch must appear in a block");
2322
2323 if (operands.size() < 2)
2324 return emitError(unknownLoc, "OpSwitch must at least specify selector and "
2325 "a default target");
2326
2327 if (operands.size() % 2)
2328 return emitError(unknownLoc,
2329 "OpSwitch must at have an even number of operands: "
2330 "selector, default target and any number of literal and "
2331 "label <id> pairs");
2332
2333 Value selector = getValue(operands[0]);
2334 Block *defaultBlock = getOrCreateBlock(operands[1]);
2335 Location loc = createFileLineColLoc(opBuilder);
2336
2337 SmallVector<int32_t> literals;
2338 SmallVector<Block *> blocks;
2339 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2340 literals.push_back(operands[i]);
2341 blocks.push_back(getOrCreateBlock(operands[i + 1]));
2342 }
2343
2344 SmallVector<ValueRange> targetOperands(blocks.size(), {});
2345 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2346 ArrayRef<Value>(), literals, blocks, targetOperands);
2347
2348 return success();
2349}
2350
2351namespace {
2352/// A class for putting all blocks in a structured selection/loop in a
2353/// spirv.mlir.selection/spirv.mlir.loop op.
2354class ControlFlowStructurizer {
2355public:
2356#ifndef NDEBUG
2357 ControlFlowStructurizer(Location loc, uint32_t control,
2358 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2359 Block *merge, Block *cont,
2360 llvm::ScopedPrinter &logger)
2361 : location(loc), control(control), blockMergeInfo(mergeInfo),
2362 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2363 logger(logger) {}
2364#else
2365 ControlFlowStructurizer(Location loc, uint32_t control,
2366 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2367 Block *merge, Block *cont)
2368 : location(loc), control(control), blockMergeInfo(mergeInfo),
2369 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2370#endif
2371
2372 /// Structurizes the loop at the given `headerBlock`.
2373 ///
2374 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
2375 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
2376 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
2377 /// method will also update `mergeInfo` by remapping all blocks inside to the
2378 /// newly cloned ones inside structured control flow op's regions.
2379 LogicalResult structurize();
2380
2381private:
2382 /// Creates a new spirv.mlir.selection op at the beginning of the
2383 /// `mergeBlock`.
2384 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2385
2386 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
2387 spirv::LoopOp createLoopOp(uint32_t loopControl);
2388
2389 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
2390 void collectBlocksInConstruct();
2391
2392 Location location;
2393 uint32_t control;
2394
2395 spirv::BlockMergeInfoMap &blockMergeInfo;
2396
2397 Block *headerBlock;
2398 Block *mergeBlock;
2399 Block *continueBlock; // nullptr for spirv.mlir.selection
2400
2401 SetVector<Block *> constructBlocks;
2402
2403#ifndef NDEBUG
2404 /// A logger used to emit information during the deserialzation process.
2405 llvm::ScopedPrinter &logger;
2406#endif
2407};
2408} // namespace
2409
2410spirv::SelectionOp
2411ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2412 // Create a builder and set the insertion point to the beginning of the
2413 // merge block so that the newly created SelectionOp will be inserted there.
2414 OpBuilder builder(&mergeBlock->front());
2415
2416 auto control = static_cast<spirv::SelectionControl>(selectionControl);
2417 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2418 selectionOp.addMergeBlock(builder);
2419
2420 return selectionOp;
2421}
2422
2423spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2424 // Create a builder and set the insertion point to the beginning of the
2425 // merge block so that the newly created LoopOp will be inserted there.
2426 OpBuilder builder(&mergeBlock->front());
2427
2428 auto control = static_cast<spirv::LoopControl>(loopControl);
2429 auto loopOp = spirv::LoopOp::create(builder, location, control);
2430 loopOp.addEntryAndMergeBlock(builder);
2431
2432 return loopOp;
2433}
2434
2435void ControlFlowStructurizer::collectBlocksInConstruct() {
2436 assert(constructBlocks.empty() && "expected empty constructBlocks");
2437
2438 // Put the header block in the work list first.
2439 constructBlocks.insert(headerBlock);
2440
2441 // For each item in the work list, add its successors excluding the merge
2442 // block.
2443 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
2444 for (auto *successor : constructBlocks[i]->getSuccessors())
2445 if (successor != mergeBlock)
2446 constructBlocks.insert(successor);
2447 }
2448}
2449
2450LogicalResult ControlFlowStructurizer::structurize() {
2451 Operation *op = nullptr;
2452 bool isLoop = continueBlock != nullptr;
2453 if (isLoop) {
2454 if (auto loopOp = createLoopOp(control))
2455 op = loopOp.getOperation();
2456 } else {
2457 if (auto selectionOp = createSelectionOp(control))
2458 op = selectionOp.getOperation();
2459 }
2460 if (!op)
2461 return failure();
2462 Region &body = op->getRegion(0);
2463
2464 IRMapping mapper;
2465 // All references to the old merge block should be directed to the
2466 // selection/loop merge block in the SelectionOp/LoopOp's region.
2467 mapper.map(mergeBlock, &body.back());
2468
2469 collectBlocksInConstruct();
2470
2471 // We've identified all blocks belonging to the selection/loop's region. Now
2472 // need to "move" them into the selection/loop. Instead of really moving the
2473 // blocks, in the following we copy them and remap all values and branches.
2474 // This is because:
2475 // * Inserting a block into a region requires the block not in any region
2476 // before. But selections/loops can nest so we can create selection/loop ops
2477 // in a nested manner, which means some blocks may already be in a
2478 // selection/loop region when to be moved again.
2479 // * It's much trickier to fix up the branches into and out of the loop's
2480 // region: we need to treat not-moved blocks and moved blocks differently:
2481 // Not-moved blocks jumping to the loop header block need to jump to the
2482 // merge point containing the new loop op but not the loop continue block's
2483 // back edge. Moved blocks jumping out of the loop need to jump to the
2484 // merge block inside the loop region but not other not-moved blocks.
2485 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2486 // logic.
2487
2488 // Create a corresponding block in the SelectionOp/LoopOp's region for each
2489 // block in this loop construct.
2490 OpBuilder builder(body);
2491 for (auto *block : constructBlocks) {
2492 // Create a block and insert it before the selection/loop merge block in the
2493 // SelectionOp/LoopOp's region.
2494 auto *newBlock = builder.createBlock(&body.back());
2495 mapper.map(block, newBlock);
2496 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2497 << " from block " << block << "\n");
2498 if (!isFnEntryBlock(block)) {
2499 for (BlockArgument blockArg : block->getArguments()) {
2500 auto newArg =
2501 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2502 mapper.map(blockArg, newArg);
2503 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2504 << blockArg << " to " << newArg << "\n");
2505 }
2506 } else {
2507 LLVM_DEBUG(logger.startLine()
2508 << "[cf] block " << block << " is a function entry block\n");
2509 }
2510
2511 for (auto &op : *block)
2512 newBlock->push_back(op.clone(mapper));
2513 }
2514
2515 // Go through all ops and remap the operands.
2516 auto remapOperands = [&](Operation *op) {
2517 for (auto &operand : op->getOpOperands())
2518 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2519 operand.set(mappedOp);
2520 for (auto &succOp : op->getBlockOperands())
2521 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2522 succOp.set(mappedOp);
2523 };
2524 for (auto &block : body)
2525 block.walk(remapOperands);
2526
2527 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2528 // the selection/loop construct into its region. Next we need to fix the
2529 // connections between this new SelectionOp/LoopOp with existing blocks.
2530
2531 // All existing incoming branches should go to the merge block, where the
2532 // SelectionOp/LoopOp resides right now.
2533 headerBlock->replaceAllUsesWith(mergeBlock);
2534
2535 LLVM_DEBUG({
2536 logger.startLine() << "[cf] after cloning and fixing references:\n";
2537 headerBlock->getParentOp()->print(logger.getOStream());
2538 logger.startLine() << "\n";
2539 });
2540
2541 if (isLoop) {
2542 if (!mergeBlock->args_empty()) {
2543 return mergeBlock->getParentOp()->emitError(
2544 "OpPhi in loop merge block unsupported");
2545 }
2546
2547 // The loop header block may have block arguments. Since now we place the
2548 // loop op inside the old merge block, we need to make sure the old merge
2549 // block has the same block argument list.
2550 for (BlockArgument blockArg : headerBlock->getArguments())
2551 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2552
2553 // If the loop header block has block arguments, make sure the spirv.Branch
2554 // op matches.
2555 SmallVector<Value, 4> blockArgs;
2556 if (!headerBlock->args_empty())
2557 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2558
2559 // The loop entry block should have a unconditional branch jumping to the
2560 // loop header block.
2561 builder.setInsertionPointToEnd(&body.front());
2562 spirv::BranchOp::create(builder, location, mapper.lookupOrNull(headerBlock),
2563 ArrayRef<Value>(blockArgs));
2564 }
2565
2566 // Values defined inside the selection region that need to be yielded outside
2567 // the region.
2568 SmallVector<Value> valuesToYield;
2569 // Outside uses of values that were sunk into the selection region. Those uses
2570 // will be replaced with values returned by the SelectionOp.
2571 SmallVector<Value> outsideUses;
2572
2573 // Move block arguments of the original block (`mergeBlock`) into the merge
2574 // block inside the selection (`body.back()`). Values produced by block
2575 // arguments will be yielded by the selection region. We do not update uses or
2576 // erase original block arguments yet. It will be done later in the code.
2577 //
2578 // Code below is not executed for loops as it would interfere with the logic
2579 // above. Currently block arguments in the merge block are not supported, but
2580 // instead, the code above copies those arguments from the header block into
2581 // the merge block. As such, running the code would yield those copied
2582 // arguments that is most likely not a desired behaviour. This may need to be
2583 // revisited in the future.
2584 if (!isLoop)
2585 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2586 // Create new block arguments in the last block ("merge block") of the
2587 // selection region. We create one argument for each argument in
2588 // `mergeBlock`. This new value will need to be yielded, and the original
2589 // value replaced, so add them to appropriate vectors.
2590 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2591 valuesToYield.push_back(body.back().getArguments().back());
2592 outsideUses.push_back(blockArg);
2593 }
2594
2595 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2596 // cleaned up.
2597 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2598 // First we need to drop all operands' references inside all blocks. This is
2599 // needed because we can have blocks referencing SSA values from one another.
2600 for (auto *block : constructBlocks)
2601 block->dropAllReferences();
2602
2603 // All internal uses should be removed from original blocks by now, so
2604 // whatever is left is an outside use and will need to be yielded from
2605 // the newly created selection / loop region.
2606 for (Block *block : constructBlocks) {
2607 for (Operation &op : *block) {
2608 if (!op.use_empty())
2609 for (Value result : op.getResults()) {
2610 valuesToYield.push_back(mapper.lookupOrNull(result));
2611 outsideUses.push_back(result);
2612 }
2613 }
2614 for (BlockArgument &arg : block->getArguments()) {
2615 if (!arg.use_empty()) {
2616 valuesToYield.push_back(mapper.lookupOrNull(arg));
2617 outsideUses.push_back(arg);
2618 }
2619 }
2620 }
2621
2622 assert(valuesToYield.size() == outsideUses.size());
2623
2624 // If we need to yield any values from the selection / loop region we will
2625 // take care of it here.
2626 if (!valuesToYield.empty()) {
2627 LLVM_DEBUG(logger.startLine()
2628 << "[cf] yielding values from the selection / loop region\n");
2629
2630 // Update `mlir.merge` with values to be yield.
2631 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2632 Operation *merge = llvm::getSingleElement(mergeOps);
2633 assert(merge);
2634 merge->setOperands(valuesToYield);
2635
2636 // MLIR does not allow changing the number of results of an operation, so
2637 // we create a new SelectionOp / LoopOp with required list of results and
2638 // move the region from the initial SelectionOp / LoopOp. The initial
2639 // operation is then removed. Since we move the region to the new op all
2640 // links between blocks and remapping we have previously done should be
2641 // preserved.
2642 builder.setInsertionPoint(&mergeBlock->front());
2643
2644 Operation *newOp = nullptr;
2645
2646 if (isLoop)
2647 newOp = spirv::LoopOp::create(builder, location,
2648 TypeRange(ValueRange(outsideUses)),
2649 static_cast<spirv::LoopControl>(control));
2650 else
2651 newOp = spirv::SelectionOp::create(
2652 builder, location, TypeRange(ValueRange(outsideUses)),
2653 static_cast<spirv::SelectionControl>(control));
2654
2655 newOp->getRegion(0).takeBody(body);
2656
2657 // Remove initial op and swap the pointer to the newly created one.
2658 op->erase();
2659 op = newOp;
2660
2661 // Update all outside uses to use results of the SelectionOp / LoopOp and
2662 // remove block arguments from the original merge block.
2663 for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2664 outsideUses[i].replaceAllUsesWith(op->getResult(i));
2665
2666 // We do not support block arguments in loop merge block. Also running this
2667 // function with loop would break some of the loop specific code above
2668 // dealing with block arguments.
2669 if (!isLoop)
2670 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2671 }
2672
2673 // Check that whether some op in the to-be-erased blocks still has uses. Those
2674 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2675 // region. We cannot handle such cases given that once a value is sinked into
2676 // the SelectionOp/LoopOp's region, there is no escape for it.
2677 for (auto *block : constructBlocks) {
2678 if (!block->use_empty())
2679 return emitError(block->getParent()->getLoc(),
2680 "failed control flow structurization: "
2681 "block has uses outside of the "
2682 "enclosing selection/loop construct");
2683 for (Operation &op : *block)
2684 if (!op.use_empty())
2685 return op.emitOpError("failed control flow structurization: value has "
2686 "uses outside of the "
2687 "enclosing selection/loop construct");
2688 for (BlockArgument &arg : block->getArguments())
2689 if (!arg.use_empty())
2690 return emitError(arg.getLoc(), "failed control flow structurization: "
2691 "block argument has uses outside of the "
2692 "enclosing selection/loop construct");
2693 }
2694
2695 // Then erase all old blocks.
2696 for (auto *block : constructBlocks) {
2697 // We've cloned all blocks belonging to this construct into the structured
2698 // control flow op's region. Among these blocks, some may compose another
2699 // selection/loop. If so, they will be recorded within blockMergeInfo.
2700 // We need to update the pointers there to the newly remapped ones so we can
2701 // continue structurizing them later.
2702 //
2703 // We need to walk each block as constructBlocks do not include blocks
2704 // internal to ops already structured within those blocks. It is not
2705 // fully clear to me why the mergeInfo of blocks (yet to be structured)
2706 // inside already structured selections/loops get invalidated and needs
2707 // updating, however the following example code can cause a crash (depending
2708 // on the structuring order), when the most inner selection is being
2709 // structured after the outer selection and loop have been already
2710 // structured:
2711 //
2712 // spirv.mlir.for {
2713 // // ...
2714 // spirv.mlir.selection {
2715 // // ..
2716 // // A selection region that hasn't been yet structured!
2717 // // ..
2718 // }
2719 // // ...
2720 // }
2721 //
2722 // If the loop gets structured after the outer selection, but before the
2723 // inner selection. Moving the already structured selection inside the loop
2724 // will invalidate the mergeInfo of the region that is not yet structured.
2725 // Just going over constructBlocks will not check and updated header blocks
2726 // inside the already structured selection region. Walking block fixes that.
2727 //
2728 // TODO: If structuring was done in a fixed order starting with inner
2729 // most constructs this most likely not be an issue and the whole code
2730 // section could be removed. However, with the current non-deterministic
2731 // order this is not possible.
2732 //
2733 // TODO: The asserts in the following assumes input SPIR-V blob forms
2734 // correctly nested selection/loop constructs. We should relax this and
2735 // support error cases better.
2736 auto updateMergeInfo = [&](Block *block) -> WalkResult {
2737 auto it = blockMergeInfo.find(block);
2738 if (it != blockMergeInfo.end()) {
2739 // Use the original location for nested selection/loop ops.
2740 Location loc = it->second.loc;
2741
2742 Block *newHeader = mapper.lookupOrNull(block);
2743 if (!newHeader)
2744 return emitError(loc, "failed control flow structurization: nested "
2745 "loop header block should be remapped!");
2746
2747 Block *newContinue = it->second.continueBlock;
2748 if (newContinue) {
2749 newContinue = mapper.lookupOrNull(newContinue);
2750 if (!newContinue)
2751 return emitError(loc, "failed control flow structurization: nested "
2752 "loop continue block should be remapped!");
2753 }
2754
2755 Block *newMerge = it->second.mergeBlock;
2756 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2757 newMerge = mappedTo;
2758
2759 // The iterator should be erased before adding a new entry into
2760 // blockMergeInfo to avoid iterator invalidation.
2761 blockMergeInfo.erase(it);
2762 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2763 newContinue);
2764 }
2765
2766 return WalkResult::advance();
2767 };
2768
2769 if (block->walk(updateMergeInfo).wasInterrupted())
2770 return failure();
2771
2772 // The structured selection/loop's entry block does not have arguments.
2773 // If the function's header block is also part of the structured control
2774 // flow, we cannot just simply erase it because it may contain arguments
2775 // matching the function signature and used by the cloned blocks.
2776 if (isFnEntryBlock(block)) {
2777 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2778 << " to only contain a spirv.Branch op\n");
2779 // Still keep the function entry block for the potential block arguments,
2780 // but replace all ops inside with a branch to the merge block.
2781 block->clear();
2782 builder.setInsertionPointToEnd(block);
2783 spirv::BranchOp::create(builder, location, mergeBlock);
2784 } else {
2785 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2786 block->erase();
2787 }
2788 }
2789
2790 LLVM_DEBUG(logger.startLine()
2791 << "[cf] after structurizing construct with header block "
2792 << headerBlock << ":\n"
2793 << *op << "\n");
2794
2795 return success();
2796}
2797
2799 LLVM_DEBUG({
2800 logger.startLine()
2801 << "//----- [phi] start wiring up block arguments -----//\n";
2802 logger.indent();
2803 });
2804
2805 OpBuilder::InsertionGuard guard(opBuilder);
2806
2807 for (const auto &info : blockPhiInfo) {
2808 Block *block = info.first.first;
2809 Block *target = info.first.second;
2810 const BlockPhiInfo &phiInfo = info.second;
2811 LLVM_DEBUG({
2812 logger.startLine() << "[phi] block " << block << "\n";
2813 logger.startLine() << "[phi] before creating block argument:\n";
2814 block->getParentOp()->print(logger.getOStream());
2815 logger.startLine() << "\n";
2816 });
2817
2818 // Set insertion point to before this block's terminator early because we
2819 // may materialize ops via getValue() call.
2820 auto *op = block->getTerminator();
2821 opBuilder.setInsertionPoint(op);
2822
2823 SmallVector<Value, 4> blockArgs;
2824 blockArgs.reserve(phiInfo.size());
2825 for (uint32_t valueId : phiInfo) {
2826 if (Value value = getValue(valueId)) {
2827 blockArgs.push_back(value);
2828 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2829 << " id = " << valueId << "\n");
2830 } else {
2831 return emitError(unknownLoc, "OpPhi references undefined value!");
2832 }
2833 }
2834
2835 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2836 // Replace the previous branch op with a new one with block arguments.
2837 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2838 branchOp.getTarget(), blockArgs);
2839 branchOp.erase();
2840 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2841 assert((branchCondOp.getTrueBlock() == target ||
2842 branchCondOp.getFalseBlock() == target) &&
2843 "expected target to be either the true or false target");
2844 if (target == branchCondOp.getTrueTarget())
2845 spirv::BranchConditionalOp::create(
2846 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2847 blockArgs, branchCondOp.getFalseBlockArguments(),
2848 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2849 branchCondOp.getFalseTarget());
2850 else
2851 spirv::BranchConditionalOp::create(
2852 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2853 branchCondOp.getTrueBlockArguments(), blockArgs,
2854 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2855 branchCondOp.getFalseBlock());
2856
2857 branchCondOp.erase();
2858 } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2859 if (target == switchOp.getDefaultTarget()) {
2860 SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
2861 DenseIntElementsAttr literals =
2862 switchOp.getLiterals().value_or(DenseIntElementsAttr());
2863 spirv::SwitchOp::create(
2864 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2865 switchOp.getDefaultTarget(), blockArgs, literals,
2866 switchOp.getTargets(), targetOperands);
2867 switchOp.erase();
2868 } else {
2869 SuccessorRange targets = switchOp.getTargets();
2870 auto it = llvm::find(targets, target);
2871 assert(it != targets.end());
2872 size_t index = std::distance(targets.begin(), it);
2873 switchOp.getTargetOperandsMutable(index).assign(blockArgs);
2874 }
2875 } else {
2876 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2877 }
2878
2879 LLVM_DEBUG({
2880 logger.startLine() << "[phi] after creating block argument:\n";
2881 block->getParentOp()->print(logger.getOStream());
2882 logger.startLine() << "\n";
2883 });
2884 }
2885 blockPhiInfo.clear();
2886
2887 LLVM_DEBUG({
2888 logger.unindent();
2889 logger.startLine()
2890 << "//--- [phi] completed wiring up block arguments ---//\n";
2891 });
2892 return success();
2893}
2894
2896 // Create a copy, so we can modify keys in the original.
2897 BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2898 for (auto [block, mergeInfo] : blockMergeInfoCopy) {
2899 // Skip processing loop regions. For loop regions continueBlock is non-null.
2900 if (mergeInfo.continueBlock)
2901 continue;
2902
2903 if (!block->mightHaveTerminator())
2904 continue;
2905
2906 Operation *terminator = block->getTerminator();
2907 assert(terminator);
2908
2909 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2910 continue;
2911
2912 // Check if the current header block is a merge block of another construct.
2913 bool splitHeaderMergeBlock = false;
2914 for (const auto &[_, mergeInfo] : blockMergeInfo) {
2915 if (mergeInfo.mergeBlock == block)
2916 splitHeaderMergeBlock = true;
2917 }
2918
2919 // Do not split a block that only contains a conditional branch / switch,
2920 // unless it is also a merge block of another construct - in that case we
2921 // want to split the block. We do not want two constructs to share header /
2922 // merge block.
2923 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2924 Block *newBlock = block->splitBlock(terminator);
2925 OpBuilder builder(block, block->end());
2926 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2927
2928 // After splitting we need to update the map to use the new block as a
2929 // header.
2930 blockMergeInfo.erase(block);
2931 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2932 }
2933 }
2934
2935 return success();
2936}
2937
2939 if (!options.enableControlFlowStructurization) {
2940 LLVM_DEBUG(
2941 {
2942 logger.startLine()
2943 << "//----- [cf] skip structurizing control flow -----//\n";
2944 logger.indent();
2945 });
2946 return success();
2947 }
2948
2949 LLVM_DEBUG({
2950 logger.startLine()
2951 << "//----- [cf] start structurizing control flow -----//\n";
2952 logger.indent();
2953 });
2954
2955 LLVM_DEBUG({
2956 logger.startLine() << "[cf] split conditional blocks\n";
2957 logger.startLine() << "\n";
2958 });
2959
2960 if (failed(splitSelectionHeader())) {
2961 return failure();
2962 }
2963
2964 while (!blockMergeInfo.empty()) {
2965 Block *headerBlock = blockMergeInfo.begin()->first;
2966 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2967
2968 LLVM_DEBUG({
2969 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2970 headerBlock->print(logger.getOStream());
2971 logger.startLine() << "\n";
2972 });
2973
2974 auto *mergeBlock = mergeInfo.mergeBlock;
2975 assert(mergeBlock && "merge block cannot be nullptr");
2976 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2977 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2978 LLVM_DEBUG({
2979 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2980 mergeBlock->print(logger.getOStream());
2981 logger.startLine() << "\n";
2982 });
2983
2984 auto *continueBlock = mergeInfo.continueBlock;
2985 LLVM_DEBUG(if (continueBlock) {
2986 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2987 continueBlock->print(logger.getOStream());
2988 logger.startLine() << "\n";
2989 });
2990 // Erase this case before calling into structurizer, who will update
2991 // blockMergeInfo.
2992 blockMergeInfo.erase(blockMergeInfo.begin());
2993 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2994 blockMergeInfo, headerBlock,
2995 mergeBlock, continueBlock
2996#ifndef NDEBUG
2997 ,
2998 logger
2999#endif
3000 );
3001 if (failed(structurizer.structurize()))
3002 return failure();
3003 }
3004
3005 LLVM_DEBUG({
3006 logger.unindent();
3007 logger.startLine()
3008 << "//--- [cf] completed structurizing control flow ---//\n";
3009 });
3010 return success();
3011}
3012
3013//===----------------------------------------------------------------------===//
3014// Debug
3015//===----------------------------------------------------------------------===//
3016
3018 if (!debugLine)
3019 return unknownLoc;
3020
3021 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3022 if (fileName.empty())
3023 fileName = "<unknown>";
3024 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
3025 debugLine->column);
3026}
3027
3028LogicalResult
3030 // According to SPIR-V spec:
3031 // "This location information applies to the instructions physically
3032 // following this instruction, up to the first occurrence of any of the
3033 // following: the next end of block, the next OpLine instruction, or the next
3034 // OpNoLine instruction."
3035 if (operands.size() != 3)
3036 return emitError(unknownLoc, "OpLine must have 3 operands");
3037 debugLine = DebugLine{operands[0], operands[1], operands[2]};
3038 return success();
3039}
3040
3041void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
3042
3043LogicalResult
3045 if (operands.size() < 2)
3046 return emitError(unknownLoc, "OpString needs at least 2 operands");
3047
3048 if (!debugInfoMap.lookup(operands[0]).empty())
3049 return emitError(unknownLoc,
3050 "duplicate debug string found for result <id> ")
3051 << operands[0];
3052
3053 unsigned wordIndex = 1;
3054 StringRef debugString = decodeStringLiteral(operands, wordIndex);
3055 if (wordIndex != operands.size())
3056 return emitError(unknownLoc,
3057 "unexpected trailing words in OpString instruction");
3058
3059 debugInfoMap[operands[0]] = debugString;
3060 return success();
3061}
return success()
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:306
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
void erase()
Unlink this Block from its parent region and delete it.
Definition Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:323
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:109
iterator begin()
Definition Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< BlockOperand > getBlockOperands()
Definition Operation.h:721
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:878
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
void print(raw_ostream &os, const OpPrintingFlags &flags={})
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, PropertyRef properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:66
result_range getResults()
Definition Operation.h:441
Operation * clone(IRMapping &mapper, const CloneOptions &options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
iterator end()
Definition Region.h:56
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition Region.h:252
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult processSamplerType(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:148
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static SamplerType get(MLIRContext *context)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:125
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.