MLIR 23.0.0git
Deserializer.cpp
Go to the documentation of this file.
1//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
31#include <optional>
32
33using namespace mlir;
34
35#define DEBUG_TYPE "spirv-deserialization"
36
37//===----------------------------------------------------------------------===//
38// Utility Functions
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given `block` is a function entry block.
42static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45}
46
47//===----------------------------------------------------------------------===//
48// Deserializer Method Definitions
49//===----------------------------------------------------------------------===//
50
// Constructs a deserializer over the SPIR-V word stream `binary`.
// The output spirv.module op is created eagerly by createModuleOp(), and
// `opBuilder` is set to insert into that module's region.
// createModuleOp() reads `context` and `unknownLoc`, so those members must be
// initialized before `module`. C++ initializes members in declaration order,
// not in the order written here. NOTE(review): confirm the member declaration
// order in Deserializer.h.
// The `logger` member exists only in builds where NDEBUG is not defined.
// (The line declaring the `options` parameter is not visible in this listing.)
51spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52 MLIRContext *context,
54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56#ifndef NDEBUG
57 ,
58 logger(llvm::dbgs())
59#endif
60{
61}
62
63LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
69
70 if (failed(processHeader()))
71 return failure();
72
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode, operands)))
80 return failure();
81
82 if (failed(processInstruction(opcode, operands)))
83 return failure();
84 }
85
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
88
89 for (auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second, false))) {
91 return failure();
92 }
93 }
94
95 attachVCETriple();
96
97 LLVM_DEBUG(logger.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
99 return success();
100}
101
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
104}
105
106//===----------------------------------------------------------------------===//
107// Module structure
108//===----------------------------------------------------------------------===//
109
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
114 return cast<spirv::ModuleOp>(Operation::create(state));
115}
116
// Validates the fixed-size SPIR-V module header. It checks the word count and
// the magic number, then decodes the version word into `version`. On success,
// `curOffset` is positioned at the first word after the header.
// Only major version 1 is accepted. The accepted minor versions come from the
// MIN_VERSION_CASE expansions, whose lines are not visible in this listing.
117LogicalResult spirv::Deserializer::processHeader() {
118 if (binary.size() < spirv::kHeaderWordCount)
119 return emitError(unknownLoc,
120 "SPIR-V binary module must have a 5-word header");
121
122 if (binary[0] != spirv::kMagicNumber)
123 return emitError(unknownLoc, "incorrect magic number");
124
125 // Version number bytes: 0 | major number | minor number | 0
// Each shift pair isolates one byte. (w << 8) >> 24 keeps bits 16-23, the
// major version. (w << 16) >> 24 keeps bits 8-15, the minor version.
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
// MIN_VERSION_CASE(v) maps minor version `v` to spirv::Version::V_1_v.
130#define MIN_VERSION_CASE(v) \
131 case v: \
132 version = spirv::Version::V_1_##v; \
133 break
134
142#undef MIN_VERSION_CASE
143 default:
144 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
145 << minorVersion;
146 }
147 } else {
148 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
149 << majorVersion;
150 }
151
152 // TODO: generator number, bound, schema
153 curOffset = spirv::kHeaderWordCount;
154 return success();
155}
156
157LogicalResult
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc, "OpCapability must have one parameter");
161
162 auto cap = spirv::symbolizeCapability(operands[0]);
163 if (!cap)
164 return emitError(unknownLoc, "unknown capability: ") << operands[0];
165
166 capabilities.insert(*cap);
167 return success();
168}
169
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
171 if (words.empty()) {
172 return emitError(
173 unknownLoc,
174 "OpExtension must have a literal string for the extension name");
175 }
176
177 unsigned wordIndex = 0;
178 StringRef extName = decodeStringLiteral(words, wordIndex);
179 if (wordIndex != words.size())
180 return emitError(unknownLoc,
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
183 if (!ext)
184 return emitError(unknownLoc, "unknown extension: ") << extName;
185
186 extensions.insert(*ext);
187 return success();
188}
189
190LogicalResult
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
193 return emitError(unknownLoc,
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
196 }
197
198 unsigned wordIndex = 1;
199 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
200 if (wordIndex != words.size()) {
201 return emitError(unknownLoc,
202 "unexpected trailing words in OpExtInstImport");
203 }
204 return success();
205}
206
207void spirv::Deserializer::attachVCETriple() {
208 (*module)->setAttr(
209 spirv::ModuleOp::getVCETripleAttrName(),
210 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
211 extensions.getArrayRef(), context));
212}
213
214LogicalResult
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc, "OpMemoryModel must have two operands");
218
219 (*module)->setAttr(
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel>(operands.front())));
223
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel>(operands.back())));
227
228 return success();
229}
230
// Shared helper for the CacheControlLoadINTEL and CacheControlStoreINTEL
// decorations, used by processDecoration with the full OpDecorate `words`.
// It expects exactly 4 words: <target id>, decoration, cache level, and the
// cache-control enumerant. It builds an AttrTy from the last two words and
// appends it to the ArrayAttr already stored under `symbol` for the target
// <id>, so repeated decorations accumulate.
// Some signature lines and the declaration of `attrs` are not visible in this
// listing.
231template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
233 Location loc, OpBuilder &opBuilder,
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc, "OpDecorate with ")
238 << decorationName << " needs a cache control integer literal and a "
239 << cacheControlKind << " cache control literal";
240 }
241 unsigned cacheLevel = words[2];
// NOTE(review): the enumerant word is cast without validation.
242 auto cacheControlAttr = static_cast<EnumTy>(words[3]);
243 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
// Keep any entries previously recorded for this <id>/symbol, then append.
245 if (auto attrList =
246 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
250 return success();
251}
252
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
254 // TODO: This function should also be auto-generated. For now, since only a
255 // few decorations are processed/handled in a meaningful manner, going with a
256 // manual implementation.
257 if (words.size() < 2) {
258 return emitError(
259 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
260 }
261 auto decorationName =
262 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
265 }
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (static_cast<spirv::Decoration>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc, "OpDecorate with ")
271 << decorationName << " needs a single integer literal";
272 }
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode>(words[2])));
276 break;
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc, "OpDecorate with ")
280 << decorationName << " needs a single integer literal";
281 }
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode>(words[2])));
285 break;
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 case spirv::Decoration::Location:
289 case spirv::Decoration::SpecId:
290 case spirv::Decoration::Index:
291 if (words.size() != 3) {
292 return emitError(unknownLoc, "OpDecorate with ")
293 << decorationName << " needs a single integer literal";
294 }
295 decorations[words[0]].set(
296 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
297 break;
298 case spirv::Decoration::BuiltIn:
299 if (words.size() != 3) {
300 return emitError(unknownLoc, "OpDecorate with ")
301 << decorationName << " needs a single integer literal";
302 }
303 decorations[words[0]].set(
304 symbol, opBuilder.getStringAttr(
305 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
306 break;
307 case spirv::Decoration::ArrayStride:
308 if (words.size() != 3) {
309 return emitError(unknownLoc, "OpDecorate with ")
310 << decorationName << " needs a single integer literal";
311 }
312 typeDecorations[words[0]] = words[2];
313 break;
314 case spirv::Decoration::LinkageAttributes: {
315 if (words.size() < 4) {
316 return emitError(unknownLoc, "OpDecorate with ")
317 << decorationName
318 << " needs at least 1 string and 1 integer literal";
319 }
320 // LinkageAttributes has two parameters ["linkageName", linkageType]
321 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
322 // "linkageName" is a stringliteral encoded as uint32_t,
323 // hence the size of name is variable length which results in words.size()
324 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
325 // 3 + ceildiv(strlen(name), 4).
326 unsigned wordIndex = 2;
327 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
328 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
329 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
330 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
331 StringAttr::get(context, linkageName), linkageTypeAttr);
332 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
333 break;
334 }
335 case spirv::Decoration::Aliased:
336 case spirv::Decoration::AliasedPointer:
337 case spirv::Decoration::Block:
338 case spirv::Decoration::BufferBlock:
339 case spirv::Decoration::Flat:
340 case spirv::Decoration::NonReadable:
341 case spirv::Decoration::NonWritable:
342 case spirv::Decoration::NoPerspective:
343 case spirv::Decoration::NoSignedWrap:
344 case spirv::Decoration::NoUnsignedWrap:
345 case spirv::Decoration::RelaxedPrecision:
346 case spirv::Decoration::Restrict:
347 case spirv::Decoration::RestrictPointer:
348 case spirv::Decoration::NoContraction:
349 case spirv::Decoration::Constant:
350 case spirv::Decoration::Invariant:
351 case spirv::Decoration::Patch:
352 case spirv::Decoration::Coherent:
353 if (words.size() != 2) {
354 return emitError(unknownLoc, "OpDecorate with ")
355 << decorationName << " needs a single target <id>";
356 }
357 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
358 break;
359 case spirv::Decoration::CacheControlLoadINTEL: {
360 LogicalResult res = deserializeCacheControlDecoration<
361 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
362 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
363 "load");
364 if (failed(res))
365 return res;
366 break;
367 }
368 case spirv::Decoration::CacheControlStoreINTEL: {
369 LogicalResult res = deserializeCacheControlDecoration<
370 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
371 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
372 "store");
373 if (failed(res))
374 return res;
375 break;
376 }
377 default:
378 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
379 }
380 return success();
381}
382
383LogicalResult
384spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
385 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
386 if (words.size() < 3) {
387 return emitError(unknownLoc,
388 "OpMemberDecorate must have at least 3 operands");
389 }
390
391 auto decoration = static_cast<spirv::Decoration>(words[2]);
392 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
393 return emitError(unknownLoc,
394 " missing offset specification in OpMemberDecorate with "
395 "Offset decoration");
396 }
397 ArrayRef<uint32_t> decorationOperands;
398 if (words.size() > 3) {
399 decorationOperands = words.slice(3);
400 }
401 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
402 return success();
403}
404
405LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
406 if (words.size() < 3) {
407 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
408 }
409 unsigned wordIndex = 2;
410 auto name = decodeStringLiteral(words, wordIndex);
411 if (wordIndex != words.size()) {
412 return emitError(unknownLoc,
413 "unexpected trailing words in OpMemberName instruction");
414 }
415 memberNameMap[words[0]][words[1]] = name;
416 return success();
417}
418
// (The line naming this method is not visible in this listing.)
// Computes the attribute dictionary for function argument `argIndex`, whose
// result <id> is `argID`, and stores it into argAttrs[argIndex].
// - An argument with no recorded decorations gets an empty dictionary.
// - Otherwise exactly one of Aliased, Restrict, AliasedPointer,
//   RestrictPointer or RelaxedPrecision must be present. It is stored under
//   the `spirv::DecorationAttr::name` key.
// - Finding more than one of them, or none of them, is an error.
420 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
421 if (!decorations.contains(argID)) {
422 argAttrs[argIndex] = DictionaryAttr::get(context, {});
423 return success();
424 }
425
426 spirv::DecorationAttr foundDecorationAttr;
427 for (NamedAttribute decAttr : decorations[argID]) {
// Match this decoration against the aliasing-related decorations.
428 for (auto decoration :
429 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
430 spirv::Decoration::AliasedPointer,
431 spirv::Decoration::RestrictPointer}) {
432
433 if (decAttr.getName() !=
434 getSymbolDecoration(stringifyDecoration(decoration)))
435 continue;
436
437 if (foundDecorationAttr)
438 return emitError(unknownLoc,
439 "more than one Aliased/Restrict decorations for "
440 "function argument with result <id> ")
441 << argID;
442
443 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
444 break;
445 }
446
447 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
448 spirv::Decoration::RelaxedPrecision))) {
449 // TODO: Current implementation supports only one decoration per function
450 // parameter so RelaxedPrecision cannot be applied at the same time as,
451 // for example, Aliased/Restrict/etc. This should be relaxed to allow any
452 // combination of decoration allowed by the spec to be supported.
453 if (foundDecorationAttr)
454 return emitError(unknownLoc, "already found a decoration for function "
455 "argument with result <id> ")
456 << argID;
457
458 foundDecorationAttr = spirv::DecorationAttr::get(
459 context, spirv::Decoration::RelaxedPrecision);
460 }
461 }
462
// The argument has decorations, but none that this function supports.
463 if (!foundDecorationAttr)
464 return emitError(unknownLoc, "unimplemented decoration support for "
465 "function argument with result <id> ")
466 << argID;
467
468 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
469 foundDecorationAttr);
470 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
471 return success();
472}
473
// Handles OpFunction. Its operands are: result type, result <id>, function
// control, and function type <id>. (The line naming this method is not
// visible in this listing.) The handler:
// - creates the spirv.func op and copies any decorations recorded for the
//   function <id> onto it;
// - consumes one OpFunctionParameter per function-type input;
// - consumes the body up to and including OpFunctionEnd, which is handed to
//   processFunctionEnd().
// Functions with Import linkage have their body erased, since they are
// declarations.
474LogicalResult
476 if (curFunction) {
477 return emitError(unknownLoc, "found function inside function");
478 }
479
480 // Get the result type
481 if (operands.size() != 4) {
482 return emitError(unknownLoc, "OpFunction must have 4 parameters");
483 }
484 Type resultType = getType(operands[0]);
485 if (!resultType) {
486 return emitError(unknownLoc, "undefined result type from <id> ")
487 << operands[0];
488 }
489
490 uint32_t fnID = operands[1];
491 if (funcMap.count(fnID)) {
492 return emitError(unknownLoc, "duplicate function definition/declaration");
493 }
494
495 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
496 if (!fnControl) {
497 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
498 }
499
500 Type fnType = getType(operands[3]);
501 if (!fnType || !isa<FunctionType>(fnType)) {
502 return emitError(unknownLoc, "unknown function type from <id> ")
503 << operands[3];
504 }
505 auto functionType = cast<FunctionType>(fnType);
506
// The declared return type must agree with the function type's results.
507 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
508 (functionType.getNumResults() == 1 &&
509 functionType.getResult(0) != resultType)) {
510 return emitError(unknownLoc, "mismatch in function type ")
511 << functionType << " and return type " << resultType << " specified";
512 }
513
514 std::string fnName = getFunctionSymbol(fnID);
515 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
516 functionType, fnControl.value());
517 // Processing other function attributes.
518 if (decorations.count(fnID)) {
519 for (auto attr : decorations[fnID].getAttrs()) {
520 funcOp->setAttr(attr.getName(), attr.getValue());
521 }
522 }
523 curFunction = funcMap[fnID] = funcOp;
524 auto *entryBlock = funcOp.addEntryBlock();
525 LLVM_DEBUG({
526 logger.startLine()
527 << "//===-------------------------------------------===//\n";
528 logger.startLine() << "[fn] name: " << fnName << "\n";
529 logger.startLine() << "[fn] type: " << fnType << "\n";
530 logger.startLine() << "[fn] ID: " << fnID << "\n";
531 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
532 logger.indent();
533 });
534
// One attribute dictionary per argument, filled by setFunctionArgAttrs().
535 SmallVector<Attribute> argAttrs;
536 argAttrs.resize(functionType.getNumInputs());
537
538 // Parse the op argument instructions
539 if (functionType.getNumInputs()) {
540 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
541 auto argType = functionType.getInput(i);
542 spirv::Opcode opcode = spirv::Opcode::OpNop;
543 ArrayRef<uint32_t> operands;
544 if (failed(sliceInstruction(opcode, operands,
545 spirv::Opcode::OpFunctionParameter))) {
546 return failure();
547 }
548 if (opcode != spirv::Opcode::OpFunctionParameter) {
549 return emitError(
550 unknownLoc,
551 "missing OpFunctionParameter instruction for argument ")
552 << i;
553 }
554 if (operands.size() != 2) {
555 return emitError(
556 unknownLoc,
557 "expected result type and result <id> for OpFunctionParameter");
558 }
559 auto argDefinedType = getType(operands[0]);
560 if (!argDefinedType || argDefinedType != argType) {
561 return emitError(unknownLoc,
562 "mismatch in argument type between function type "
563 "definition ")
564 << functionType << " and argument type definition "
565 << argDefinedType << " at argument " << i;
566 }
567 if (getValue(operands[1])) {
568 return emitError(unknownLoc, "duplicate definition of result <id> ")
569 << operands[1];
570 }
571 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
572 return failure();
573 }
574
// Map the parameter's result <id> to the entry block argument.
575 auto argValue = funcOp.getArgument(i);
576 valueMap[operands[1]] = argValue;
577 }
578 }
579
// Only attach argument attributes when at least one argument has some.
580 if (llvm::any_of(argAttrs, [](Attribute attr) {
581 auto argAttr = cast<DictionaryAttr>(attr);
582 return !argAttr.empty();
583 }))
584 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
585
586 // entryBlock is needed to access the arguments. Once that is done, we can
587 // erase the block for functions with 'Import' LinkageAttributes, since these
588 // are essentially function declarations, so they have no body.
589 auto linkageAttr = funcOp.getLinkageAttributes();
590 auto hasImportLinkage =
591 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
592 spirv::LinkageType::Import);
593 if (hasImportLinkage)
594 funcOp.eraseBody();
595
596 // RAII guard to reset the insertion point to the module's region after
597 // deserializing the body of this function.
598 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
599
600 spirv::Opcode opcode = spirv::Opcode::OpNop;
601 ArrayRef<uint32_t> instOperands;
602
603 // Special handling for the entry block. We need to make sure it starts with
604 // an OpLabel instruction. The entry block takes the same parameters as the
605 // function. All other blocks do not take any parameter. We have already
606 // created the entry block, here we need to register it to the correct label
607 // <id>.
608 if (failed(sliceInstruction(opcode, instOperands,
609 spirv::Opcode::OpFunctionEnd))) {
610 return failure();
611 }
612 if (opcode == spirv::Opcode::OpFunctionEnd) {
613 return processFunctionEnd(instOperands);
614 }
615 if (opcode != spirv::Opcode::OpLabel) {
616 return emitError(unknownLoc, "a basic block must start with OpLabel");
617 }
618 if (instOperands.size() != 1) {
619 return emitError(unknownLoc, "OpLabel should only have result <id>");
620 }
// NOTE(review): for an Import-linkage function, `entryBlock` was erased by
// eraseBody() above. If such a function still has an OpLabel, this stores a
// dangling block pointer. Confirm this input is rejected elsewhere.
621 blockMap[instOperands[0]] = entryBlock;
622 if (failed(processLabel(instOperands))) {
623 return failure();
624 }
625
626 // Then process all the other instructions in the function until we hit
627 // OpFunctionEnd.
// A failing sliceInstruction() also ends this loop. The check right after
// the loop then reports failure.
628 while (succeeded(sliceInstruction(opcode, instOperands,
629 spirv::Opcode::OpFunctionEnd)) &&
630 opcode != spirv::Opcode::OpFunctionEnd) {
631 if (failed(processInstruction(opcode, instOperands))) {
632 return failure();
633 }
634 }
635 if (opcode != spirv::Opcode::OpFunctionEnd) {
636 return failure();
637 }
638
639 return processFunctionEnd(instOperands);
640}
641
// Handles OpFunctionEnd for the function currently being deserialized. (The
// line naming this method is not visible in this listing.) It:
// - rejects any operands;
// - wires up block arguments derived from OpPhi;
// - structurizes control flow;
// - clears the current-block and current-function state.
642LogicalResult
644 // Process OpFunctionEnd.
645 if (!operands.empty()) {
646 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
647 }
648
649 // Wire up block arguments from OpPhi instructions.
650 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
651 // ops.
652 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
653 return failure();
654 }
655
656 curBlock = nullptr;
657 curFunction = std::nullopt;
658
659 LLVM_DEBUG({
660 logger.unindent();
661 logger.startLine()
662 << "//===-------------------------------------------===//\n";
663 });
664 return success();
665}
666
// Handles OpGraphEntryPointARM. Its operand layout is:
//   <graph id> <name literal string> <interface global variable id>...
// (The line naming this method and the line declaring `interface` are not
// visible in this listing.) It:
// - renames the previously created graph op to the given name;
// - marks it as an entry point;
// - inserts a GraphEntryPointARMOp right before it that references the
//   interface global variables.
// NOTE(review): the first error message spells "defintion"; consider fixing
// it to "definition".
667LogicalResult
669 if (operands.size() < 2) {
670 return emitError(unknownLoc,
671 "missing graph defintion in OpGraphEntryPointARM");
672 }
673
674 unsigned wordIndex = 0;
675 uint32_t graphID = operands[wordIndex++];
676 if (!graphMap.contains(graphID)) {
677 return emitError(unknownLoc,
678 "missing graph definition/declaration with id ")
679 << graphID;
680 }
681
682 spirv::GraphARMOp graphARM = graphMap[graphID];
683 StringRef name = decodeStringLiteral(operands, wordIndex);
684 graphARM.setSymName(name);
685 graphARM.setEntryPoint(true);
686
// Every remaining word must name an already-deserialized global variable.
688 for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
689 if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
690 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
691 } else {
692 return emitError(unknownLoc, "undefined result <id> ")
693 << operands[wordIndex] << " while decoding OpGraphEntryPoint";
694 }
695 }
696
697 // RAII guard to reset the insertion point to previous value when done.
698 OpBuilder::InsertionGuard insertionGuard(opBuilder);
699 opBuilder.setInsertionPoint(graphARM);
700 spirv::GraphEntryPointARMOp::create(
701 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
702 opBuilder.getArrayAttr(interface));
703
704 return success();
705}
706
// Handles OpGraphARM. Its operands are the graph type <id> and the result
// <id>. (The line naming this method is not visible in this listing.) It:
// - creates the GraphARMOp;
// - binds each OpGraphInputARM result <id> to the block argument selected by
//   that instruction's input-index constant;
// - sizes `graphOutputs` to the number of graph results;
// - processes instructions up to and including OpGraphEndARM.
707LogicalResult
709 if (curGraph) {
710 return emitError(unknownLoc, "found graph inside graph");
711 }
712 // Get the result type.
713 if (operands.size() < 2) {
714 return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
715 }
716
717 Type type = getType(operands[0]);
718 if (!type || !isa<GraphType>(type)) {
719 return emitError(unknownLoc, "unknown graph type from <id> ")
720 << operands[0];
721 }
722 auto graphType = cast<GraphType>(type);
723 if (graphType.getNumResults() <= 0) {
724 return emitError(unknownLoc, "expected at least one result");
725 }
726
727 uint32_t graphID = operands[1];
728 if (graphMap.count(graphID)) {
729 return emitError(unknownLoc, "duplicate graph definition/declaration");
730 }
731
732 std::string graphName = getGraphSymbol(graphID);
733 auto graphOp =
734 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
735 curGraph = graphMap[graphID] = graphOp;
736 Block *entryBlock = graphOp.addEntryBlock();
737 LLVM_DEBUG({
738 logger.startLine()
739 << "//===-------------------------------------------===//\n";
740 logger.startLine() << "[graph] name: " << graphName << "\n";
741 logger.startLine() << "[graph] type: " << graphType << "\n";
742 logger.startLine() << "[graph] ID: " << graphID << "\n";
743 logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
744 logger.indent();
745 });
746
747 // Parse the op argument instructions.
// NOTE(review): unlike processFunction's parameter loop, the sliced opcode
// is not explicitly checked to be OpGraphInputARM here. Confirm that
// sliceInstruction's expected-opcode argument is enough.
748 for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
749 spirv::Opcode opcode;
750 ArrayRef<uint32_t> operands;
751 if (failed(sliceInstruction(opcode, operands,
752 spirv::Opcode::OpGraphInputARM))) {
753 return failure();
754 }
755 if (operands.size() != 3) {
756 return emitError(unknownLoc, "expected result type, result <id> and "
757 "input index for OpGraphInputARM");
758 }
759
760 Type argDefinedType = getType(operands[0]);
761 if (!argDefinedType) {
762 return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
763 }
764
765 if (argDefinedType != argType) {
766 return emitError(unknownLoc,
767 "mismatch in argument type between graph type "
768 "definition ")
769 << graphType << " and argument type definition " << argDefinedType
770 << " at argument " << index;
771 }
772 if (getValue(operands[1])) {
773 return emitError(unknownLoc, "duplicate definition of result <id> ")
774 << operands[1];
775 }
776
777 IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
778 if (!inputIndexAttr) {
779 return emitError(unknownLoc,
780 "unable to read inputIndex value from constant op ")
781 << operands[2];
782 }
// NOTE(review): the input index comes from the binary and is not
// range-checked against the number of graph inputs. A malformed binary
// could index past the entry block's arguments.
783 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
784 valueMap[operands[1]] = argValue;
785 }
786
// One slot per graph result, filled by OpGraphSetOutputARM handling.
787 graphOutputs.resize(graphType.getNumResults());
788
789 // RAII guard to reset the insertion point to the module's region after
790 // deserializing the body of this graph.
791 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
792
793 blockMap[graphID] = entryBlock;
794 if (failed(createGraphBlock(graphID))) {
795 return failure();
796 }
797
798 // Process all the instructions in the graph until and including
799 // OpGraphEndARM.
800 spirv::Opcode opcode;
801 ArrayRef<uint32_t> instOperands;
802 do {
803 if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
804 return failure();
805 }
806
807 if (failed(processInstruction(opcode, instOperands))) {
808 return failure();
809 }
810 } while (opcode != spirv::Opcode::OpGraphEndARM);
811
812 return success();
813}
814
// Handles OpGraphSetOutputARM. Its operands are the value <id> and the <id>
// of an output-index constant. (The line naming this method is not visible
// in this listing.) The value is recorded in `graphOutputs` at that index;
// the OpGraphEndARM handling later emits all collected outputs in a single
// op.
815LogicalResult
817 if (operands.size() != 2) {
818 return emitError(
819 unknownLoc,
820 "expected value id and output index for OpGraphSetOutputARM");
821 }
822
823 uint32_t id = operands[0];
824 Value value = getValue(id);
825 if (!value) {
826 return emitError(unknownLoc, "could not find result <id> ") << id;
827 }
828
829 IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
830 if (!outputIndexAttr) {
831 return emitError(unknownLoc,
832 "unable to read outputIndex value from constant op ")
833 << operands[1];
834 }
// NOTE(review): the output index is not range-checked against
// graphOutputs.size(). A negative or too-large index from a malformed binary
// would write out of bounds.
835 graphOutputs[outputIndexAttr.getInt()] = value;
836 return success();
837}
838
// Handles OpGraphEndARM. (The line naming this method is not visible in this
// listing.) It:
// - emits a GraphOutputsARMOp carrying the values collected in
//   `graphOutputs`;
// - rejects any operands;
// - resets the current-graph state.
// NOTE(review): the outputs op is created before the operand check, and any
// output slot never set by OpGraphSetOutputARM is still a null Value here.
// Confirm both cases are acceptable or rejected elsewhere.
839LogicalResult
841 // Create GraphOutputsARM instruction.
842 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
843
844 // Process OpGraphEndARM.
845 if (!operands.empty()) {
846 return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
847 }
848
849 curBlock = nullptr;
850 curGraph = std::nullopt;
851 graphOutputs.clear();
852
853 LLVM_DEBUG({
854 logger.unindent();
855 logger.startLine()
856 << "//===-------------------------------------------===//\n";
857 });
858 return success();
859}
860
// Looks up `id` in `constantMap` and returns its (attribute, type) pair, or
// std::nullopt if `id` is not a recorded constant. (The line naming this
// method is not visible in this listing.)
861std::optional<std::pair<Attribute, Type>>
863 auto constIt = constantMap.find(id);
864 if (constIt == constantMap.end())
865 return std::nullopt;
866 return constIt->getSecond();
867}
868
// Looks up `id` in `constantCompositeReplicateMap` and returns its
// (attribute, type) pair, or std::nullopt if absent. (The line naming this
// method is not visible in this listing.)
869std::optional<std::pair<Attribute, Type>>
871 if (auto it = constantCompositeReplicateMap.find(id);
872 it != constantCompositeReplicateMap.end())
873 return it->second;
874 return std::nullopt;
875}
876
// Looks up `id` in `specConstOperationMap` and returns the recorded
// materialization info, or std::nullopt if absent. (The line naming this
// method is not visible in this listing.)
877std::optional<spirv::SpecConstOperationMaterializationInfo>
879 auto constIt = specConstOperationMap.find(id);
880 if (constIt == specConstOperationMap.end())
881 return std::nullopt;
882 return constIt->getSecond();
883}
884
// Returns the symbol name for function <id>. It uses the name recorded in
// `nameMap` if that name is non-empty, and "spirv_fn_<id>" otherwise. (The
// signature line is not visible in this listing.)
886 auto funcName = nameMap.lookup(id).str();
887 if (funcName.empty()) {
888 funcName = "spirv_fn_" + std::to_string(id);
889 }
890 return funcName;
891}
892
893std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
894 std::string graphName = nameMap.lookup(id).str();
895 if (graphName.empty()) {
896 graphName = "spirv_graph_" + std::to_string(id);
897 }
898 return graphName;
899}
900
// Returns the symbol name for spec constant <id>. It uses the name recorded
// in `nameMap` if that name is non-empty, and "spirv_spec_const_<id>"
// otherwise. (The signature line is not visible in this listing.)
902 auto constName = nameMap.lookup(id).str();
903 if (constName.empty()) {
904 constName = "spirv_spec_const_" + std::to_string(id);
905 }
906 return constName;
907}
908
// Creates a SpecConstantOp at `opBuilder`'s current insertion point. The op
// is named via getSpecConstantSymbol(resultID) and holds `defaultValue`. Any
// decorations recorded for `resultID` are copied onto it, and the op is
// registered in `specConstMap`. (Part of the signature is not visible in this
// listing.)
909spirv::SpecConstantOp
911 TypedAttr defaultValue) {
912 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
913 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
914 defaultValue);
915 if (decorations.count(resultID)) {
916 for (auto attr : decorations[resultID].getAttrs())
917 op->setAttr(attr.getName(), attr.getValue());
918 }
919 specConstMap[resultID] = op;
920 return op;
921}
922
// Looks up `id` in `graphConstantMap` and returns the recorded
// materialization info, or std::nullopt if absent. (The line naming this
// method is not visible in this listing.)
923std::optional<spirv::GraphConstantARMOpMaterializationInfo>
925 auto graphConstIt = graphConstantMap.find(id);
926 if (graphConstIt == graphConstantMap.end())
927 return std::nullopt;
928 return graphConstIt->getSecond();
929}
930
// Handles a module-level OpVariable. Its operands are: result type (which
// must be a spirv.ptr), result <id>, storage class, and an optional
// initializer <id>. (The line naming this method is not visible in this
// listing.) It:
// - checks that the storage class matches the pointer type;
// - resolves the initializer to a global variable, spec constant or spec
//   constant composite symbol;
// - creates a GlobalVariableOp, copies recorded decorations onto it, and
//   registers it in `globalVariableMap`.
931LogicalResult
933 unsigned wordIndex = 0;
934 if (operands.size() < 3) {
935 return emitError(
936 unknownLoc,
937 "OpVariable needs at least 3 operands, type, <id> and storage class");
938 }
939
940 // Result Type.
941 auto type = getType(operands[wordIndex]);
942 if (!type) {
943 return emitError(unknownLoc, "unknown result type <id> : ")
944 << operands[wordIndex];
945 }
946 auto ptrType = dyn_cast<spirv::PointerType>(type);
947 if (!ptrType) {
948 return emitError(unknownLoc,
949 "expected a result type <id> to be a spirv.ptr, found : ")
950 << type;
951 }
952 wordIndex++;
953
954 // Result <id>.
955 auto variableID = operands[wordIndex];
956 auto variableName = nameMap.lookup(variableID).str();
957 if (variableName.empty()) {
958 variableName = "spirv_var_" + std::to_string(variableID);
959 }
960 wordIndex++;
961
962 // Storage class.
963 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
964 if (ptrType.getStorageClass() != storageClass) {
965 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
966 << type << " and that specified in OpVariable instruction : "
967 << stringifyStorageClass(storageClass);
968 }
969 wordIndex++;
970
971 // Initializer.
972 FlatSymbolRefAttr initializer = nullptr;
973
974 if (wordIndex < operands.size()) {
975 Operation *op = nullptr;
976
977 if (auto initOp = getGlobalVariable(operands[wordIndex]))
978 op = initOp;
979 else if (auto initOp = getSpecConstant(operands[wordIndex]))
980 op = initOp;
981 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
982 op = initOp;
983 else
// NOTE(review): this message lacks a space before "used as initializer".
984 return emitError(unknownLoc, "unknown <id> ")
985 << operands[wordIndex] << "used as initializer";
986
987 initializer = SymbolRefAttr::get(op);
988 wordIndex++;
989 }
// At most one initializer is allowed. Any further words are an error.
990 if (wordIndex != operands.size()) {
991 return emitError(unknownLoc,
992 "found more operands than expected when deserializing "
993 "OpVariable instruction, only ")
994 << wordIndex << " of " << operands.size() << " processed";
995 }
996 auto loc = createFileLineColLoc(opBuilder);
997 auto varOp = spirv::GlobalVariableOp::create(
998 opBuilder, loc, TypeAttr::get(type),
999 opBuilder.getStringAttr(variableName), initializer);
1000
1001 // Decorations.
1002 if (decorations.count(variableID)) {
1003 for (auto attr : decorations[variableID].getAttrs())
1004 varOp->setAttr(attr.getName(), attr.getValue());
1005 }
1006 globalVariableMap[variableID] = varOp;
1007 return success();
1008}
1009
1010IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
1011 auto constInfo = getConstant(id);
1012 if (!constInfo) {
1013 return nullptr;
1014 }
1015 return dyn_cast<IntegerAttr>(constInfo->first);
1016}
1017
1018LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
1019 if (operands.size() < 2) {
1020 return emitError(unknownLoc, "OpName needs at least 2 operands");
1021 }
1022 if (!nameMap.lookup(operands[0]).empty()) {
1023 return emitError(unknownLoc, "duplicate name found for result <id> ")
1024 << operands[0];
1025 }
1026 unsigned wordIndex = 1;
1027 StringRef name = decodeStringLiteral(operands, wordIndex);
1028 if (wordIndex != operands.size()) {
1029 return emitError(unknownLoc,
1030 "unexpected trailing words in OpName instruction");
1031 }
1032 nameMap[operands[0]] = name;
1033 return success();
1034}
1035
1036//===----------------------------------------------------------------------===//
1037// Type
1038//===----------------------------------------------------------------------===//
1039
1040LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
1041 ArrayRef<uint32_t> operands) {
1042 if (operands.empty()) {
1043 return emitError(unknownLoc, "type instruction with opcode ")
1044 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
1045 }
1046
1047 /// TODO: Types might be forward declared in some instructions and need to be
1048 /// handled appropriately.
1049 if (typeMap.count(operands[0])) {
1050 return emitError(unknownLoc, "duplicate definition for result <id> ")
1051 << operands[0];
1052 }
1053
1054 switch (opcode) {
1055 case spirv::Opcode::OpTypeVoid:
1056 if (operands.size() != 1)
1057 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
1058 typeMap[operands[0]] = opBuilder.getNoneType();
1059 break;
1060 case spirv::Opcode::OpTypeBool:
1061 if (operands.size() != 1)
1062 return emitError(unknownLoc, "OpTypeBool must have no parameters");
1063 typeMap[operands[0]] = opBuilder.getI1Type();
1064 break;
1065 case spirv::Opcode::OpTypeInt: {
1066 if (operands.size() != 3)
1067 return emitError(
1068 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
1069
1070 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1071 // to preserve or validate.
1072 // 0 indicates unsigned, or no signedness semantics
1073 // 1 indicates signed semantics."
1074 //
1075 // So we cannot differentiate signless and unsigned integers; always use
1076 // signless semantics for such cases.
1077 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1078 : IntegerType::SignednessSemantics::Signless;
1079 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1080 } break;
1081 case spirv::Opcode::OpTypeFloat: {
1082 if (operands.size() != 2 && operands.size() != 3)
1083 return emitError(unknownLoc,
1084 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1085 "or 3 operands (type, bitwidth, encoding), but got ")
1086 << operands.size();
1087 uint32_t bitWidth = operands[1];
1088
1089 Type floatTy;
1090 if (operands.size() == 2) {
1091 switch (bitWidth) {
1092 case 16:
1093 floatTy = opBuilder.getF16Type();
1094 break;
1095 case 32:
1096 floatTy = opBuilder.getF32Type();
1097 break;
1098 case 64:
1099 floatTy = opBuilder.getF64Type();
1100 break;
1101 default:
1102 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
1103 << bitWidth;
1104 }
1105 }
1106
1107 if (operands.size() == 3) {
1108 if (spirv::FPEncoding(operands[2]) == spirv::FPEncoding::BFloat16KHR &&
1109 bitWidth == 16)
1110 floatTy = opBuilder.getBF16Type();
1111 else if (spirv::FPEncoding(operands[2]) ==
1112 spirv::FPEncoding::Float8E4M3EXT &&
1113 bitWidth == 8)
1114 floatTy = opBuilder.getF8E4M3FNType();
1115 else if (spirv::FPEncoding(operands[2]) ==
1116 spirv::FPEncoding::Float8E5M2EXT &&
1117 bitWidth == 8)
1118 floatTy = opBuilder.getF8E5M2Type();
1119 else
1120 return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
1121 << operands[2] << " and bitWidth " << bitWidth;
1122 }
1123
1124 typeMap[operands[0]] = floatTy;
1125 } break;
1126 case spirv::Opcode::OpTypeVector: {
1127 if (operands.size() != 3) {
1128 return emitError(
1129 unknownLoc,
1130 "OpTypeVector must have element type and count parameters");
1131 }
1132 Type elementTy = getType(operands[1]);
1133 if (!elementTy) {
1134 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
1135 << operands[1];
1136 }
1137 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1138 } break;
1139 case spirv::Opcode::OpTypePointer: {
1140 return processOpTypePointer(operands);
1141 } break;
1142 case spirv::Opcode::OpTypeArray:
1143 return processArrayType(operands);
1144 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1145 return processCooperativeMatrixTypeKHR(operands);
1146 case spirv::Opcode::OpTypeFunction:
1147 return processFunctionType(operands);
1148 case spirv::Opcode::OpTypeImage:
1149 return processImageType(operands);
1150 case spirv::Opcode::OpTypeSampledImage:
1151 return processSampledImageType(operands);
1152 case spirv::Opcode::OpTypeRuntimeArray:
1153 return processRuntimeArrayType(operands);
1154 case spirv::Opcode::OpTypeStruct:
1155 return processStructType(operands);
1156 case spirv::Opcode::OpTypeMatrix:
1157 return processMatrixType(operands);
1158 case spirv::Opcode::OpTypeTensorARM:
1159 return processTensorARMType(operands);
1160 case spirv::Opcode::OpTypeGraphARM:
1161 return processGraphTypeARM(operands);
1162 default:
1163 return emitError(unknownLoc, "unhandled type instruction");
1164 }
1165 return success();
1166}
1167
1168LogicalResult
1170 if (operands.size() != 3)
1171 return emitError(unknownLoc, "OpTypePointer must have two parameters");
1172
1173 auto pointeeType = getType(operands[2]);
1174 if (!pointeeType)
1175 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
1176 << operands[2];
1177
1178 uint32_t typePointerID = operands[0];
1179 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
1180 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
1181
1182 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1183 deferredStructIt != std::end(deferredStructTypesInfos);) {
1184 for (auto *unresolvedMemberIt =
1185 std::begin(deferredStructIt->unresolvedMemberTypes);
1186 unresolvedMemberIt !=
1187 std::end(deferredStructIt->unresolvedMemberTypes);) {
1188 if (unresolvedMemberIt->first == typePointerID) {
1189 // The newly constructed pointer type can resolve one of the
1190 // deferred struct type members; update the memberTypes list and
1191 // clean the unresolvedMemberTypes list accordingly.
1192 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1193 typeMap[typePointerID];
1194 unresolvedMemberIt =
1195 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1196 } else {
1197 ++unresolvedMemberIt;
1198 }
1199 }
1200
1201 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1202 // All deferred struct type members are now resolved, set the struct body.
1203 auto structType = deferredStructIt->deferredStructType;
1204
1205 assert(structType && "expected a spirv::StructType");
1206 assert(structType.isIdentified() && "expected an indentified struct");
1207
1208 if (failed(structType.trySetBody(
1209 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1210 deferredStructIt->memberDecorationsInfo,
1211 deferredStructIt->structDecorationsInfo)))
1212 return failure();
1213
1214 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1215 } else {
1216 ++deferredStructIt;
1217 }
1218 }
1219
1220 return success();
1221}
1222
1223LogicalResult
1225 if (operands.size() != 3) {
1226 return emitError(unknownLoc,
1227 "OpTypeArray must have element type and count parameters");
1228 }
1229
1230 Type elementTy = getType(operands[1]);
1231 if (!elementTy) {
1232 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1233 << operands[1];
1234 }
1235
1236 unsigned count = 0;
1237 // TODO: The count can also come frome a specialization constant.
1238 auto countInfo = getConstant(operands[2]);
1239 if (!countInfo) {
1240 return emitError(unknownLoc, "OpTypeArray count <id> ")
1241 << operands[2] << "can only come from normal constant right now";
1242 }
1243
1244 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1245 count = intVal.getValue().getZExtValue();
1246 } else {
1247 return emitError(unknownLoc, "OpTypeArray count must come from a "
1248 "scalar integer constant instruction");
1249 }
1250
1251 typeMap[operands[0]] = spirv::ArrayType::get(
1252 elementTy, count, typeDecorations.lookup(operands[0]));
1253 return success();
1254}
1255
1256LogicalResult
1258 assert(!operands.empty() && "No operands for processing function type");
1259 if (operands.size() == 1) {
1260 return emitError(unknownLoc, "missing return type for OpTypeFunction");
1261 }
1262 auto returnType = getType(operands[1]);
1263 if (!returnType) {
1264 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1265 }
1266 SmallVector<Type, 1> argTypes;
1267 for (size_t i = 2, e = operands.size(); i < e; ++i) {
1268 auto ty = getType(operands[i]);
1269 if (!ty) {
1270 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1271 }
1272 argTypes.push_back(ty);
1273 }
1274 ArrayRef<Type> returnTypes;
1275 if (!isVoidType(returnType)) {
1276 returnTypes = llvm::ArrayRef(returnType);
1277 }
1278 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1279 return success();
1280}
1281
1283 ArrayRef<uint32_t> operands) {
1284 if (operands.size() != 6) {
1285 return emitError(unknownLoc,
1286 "OpTypeCooperativeMatrixKHR must have element type, "
1287 "scope, row and column parameters, and use");
1288 }
1289
1290 Type elementTy = getType(operands[1]);
1291 if (!elementTy) {
1292 return emitError(unknownLoc,
1293 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1294 << operands[1];
1295 }
1296
1297 std::optional<spirv::Scope> scope =
1298 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1299 if (!scope) {
1300 return emitError(
1301 unknownLoc,
1302 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1303 << operands[2];
1304 }
1305
1306 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1307 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1308 IntegerAttr useAttr = getConstantInt(operands[5]);
1309
1310 if (!rowsAttr)
1311 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1312 "undefined constant <id> ")
1313 << operands[3];
1314
1315 if (!columnsAttr)
1316 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1317 "references undefined constant <id> ")
1318 << operands[4];
1319
1320 if (!useAttr)
1321 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1322 "undefined constant <id> ")
1323 << operands[5];
1324
1325 unsigned rows = rowsAttr.getInt();
1326 unsigned columns = columnsAttr.getInt();
1327
1328 std::optional<spirv::CooperativeMatrixUseKHR> use =
1329 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1330 if (!use) {
1331 return emitError(
1332 unknownLoc,
1333 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1334 << operands[5];
1335 }
1336
1337 typeMap[operands[0]] =
1338 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1339 return success();
1340}
1341
1342LogicalResult
1344 if (operands.size() != 2) {
1345 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1346 }
1347 Type memberType = getType(operands[1]);
1348 if (!memberType) {
1349 return emitError(unknownLoc,
1350 "OpTypeRuntimeArray references undefined <id> ")
1351 << operands[1];
1352 }
1353 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1354 memberType, typeDecorations.lookup(operands[0]));
1355 return success();
1356}
1357
1358LogicalResult
1360 // TODO: Find a way to handle identified structs when debug info is stripped.
1361
1362 if (operands.empty()) {
1363 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1364 }
1365
1366 if (operands.size() == 1) {
1367 // Handle empty struct.
1368 typeMap[operands[0]] =
1369 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1370 return success();
1371 }
1372
1373 // First element is operand ID, second element is member index in the struct.
1374 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1375 SmallVector<Type, 4> memberTypes;
1376
1377 for (auto op : llvm::drop_begin(operands, 1)) {
1378 Type memberType = getType(op);
1379 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1380
1381 if (!memberType && !typeForwardPtr)
1382 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1383 << op;
1384
1385 if (!memberType)
1386 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1387
1388 memberTypes.push_back(memberType);
1389 }
1390
1393 if (memberDecorationMap.count(operands[0])) {
1394 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1395 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1396 if (allMemberDecorations.count(memberIndex)) {
1397 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1398 // Check for offset.
1399 if (memberDecoration.first == spirv::Decoration::Offset) {
1400 // If offset info is empty, resize to the number of members;
1401 if (offsetInfo.empty()) {
1402 offsetInfo.resize(memberTypes.size());
1403 }
1404 offsetInfo[memberIndex] = memberDecoration.second[0];
1405 } else {
1406 auto intType = mlir::IntegerType::get(context, 32);
1407 if (!memberDecoration.second.empty()) {
1408 memberDecorationsInfo.emplace_back(
1409 memberIndex, memberDecoration.first,
1410 IntegerAttr::get(intType, memberDecoration.second[0]));
1411 } else {
1412 memberDecorationsInfo.emplace_back(
1413 memberIndex, memberDecoration.first, UnitAttr::get(context));
1414 }
1415 }
1416 }
1417 }
1418 }
1419 }
1420
1422 if (decorations.count(operands[0])) {
1423 NamedAttrList &allDecorations = decorations[operands[0]];
1424 for (NamedAttribute &decorationAttr : allDecorations) {
1425 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1426 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
1427 assert(decoration.has_value());
1428 structDecorationsInfo.emplace_back(decoration.value(),
1429 decorationAttr.getValue());
1430 }
1431 }
1432
1433 uint32_t structID = operands[0];
1434 std::string structIdentifier = nameMap.lookup(structID).str();
1435
1436 if (structIdentifier.empty()) {
1437 assert(unresolvedMemberTypes.empty() &&
1438 "didn't expect unresolved member types");
1439 typeMap[structID] = spirv::StructType::get(
1440 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1441 } else {
1442 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1443 typeMap[structID] = structTy;
1444
1445 if (!unresolvedMemberTypes.empty())
1446 deferredStructTypesInfos.push_back(
1447 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1448 memberDecorationsInfo, structDecorationsInfo});
1449 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1450 memberDecorationsInfo,
1451 structDecorationsInfo)))
1452 return failure();
1453 }
1454
1455 // TODO: Update StructType to have member name as attribute as
1456 // well.
1457 return success();
1458}
1459
1460LogicalResult
1462 if (operands.size() != 3) {
1463 // Three operands are needed: result_id, column_type, and column_count
1464 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1465 " (result_id, column_type, and column_count)");
1466 }
1467 // Matrix columns must be of vector type
1468 Type elementTy = getType(operands[1]);
1469 if (!elementTy) {
1470 return emitError(unknownLoc,
1471 "OpTypeMatrix references undefined column type.")
1472 << operands[1];
1473 }
1474
1475 uint32_t colsCount = operands[2];
1476 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1477 return success();
1478}
1479
1480LogicalResult
1482 unsigned size = operands.size();
1483 if (size < 2 || size > 4)
1484 return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
1485 "(result_id, element_type, (rank), (shape)) ")
1486 << size;
1487
1488 Type elementTy = getType(operands[1]);
1489 if (!elementTy)
1490 return emitError(unknownLoc,
1491 "OpTypeTensorARM references undefined element type ")
1492 << operands[1];
1493
1494 if (size == 2) {
1495 typeMap[operands[0]] = TensorArmType::get({}, elementTy);
1496 return success();
1497 }
1498
1499 IntegerAttr rankAttr = getConstantInt(operands[2]);
1500 if (!rankAttr)
1501 return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
1502 "scalar integer constant instruction");
1503 unsigned rank = rankAttr.getValue().getZExtValue();
1504 if (size == 3) {
1505 SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1506 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1507 return success();
1508 }
1509
1510 std::optional<std::pair<Attribute, Type>> shapeInfo =
1511 getConstant(operands[3]);
1512 if (!shapeInfo)
1513 return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
1514 "constant instruction of type OpTypeArray");
1515
1516 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1518 for (auto dimAttr : shapeArrayAttr.getValue()) {
1519 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1520 if (!dimIntAttr)
1521 return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
1522 "dimension size");
1523 shape.push_back(dimIntAttr.getValue().getSExtValue());
1524 }
1525 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1526 return success();
1527}
1528
1529LogicalResult
1531 unsigned size = operands.size();
1532 if (size < 2) {
1533 return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
1534 "(result_id, num_inputs, (inout0_type, "
1535 "inout1_type, ...))")
1536 << size;
1537 }
1538 uint32_t numInputs = operands[1];
1539 SmallVector<Type, 1> argTypes;
1540 SmallVector<Type, 1> returnTypes;
1541 for (unsigned i = 2; i < size; ++i) {
1542 Type inOutTy = getType(operands[i]);
1543 if (!inOutTy) {
1544 return emitError(unknownLoc,
1545 "OpTypeGraphARM references undefined element type.")
1546 << operands[i];
1547 }
1548 if (i - 2 >= numInputs) {
1549 returnTypes.push_back(inOutTy);
1550 } else {
1551 argTypes.push_back(inOutTy);
1552 }
1553 }
1554 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1555 return success();
1556}
1557
1558LogicalResult
1560 if (operands.size() != 2)
1561 return emitError(unknownLoc,
1562 "OpTypeForwardPointer instruction must have two operands");
1563
1564 typeForwardPointerIDs.insert(operands[0]);
1565 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1566 // instruction that defines the actual type.
1567
1568 return success();
1569}
1570
1571LogicalResult
1573 // TODO: Add support for Access Qualifier.
1574 if (operands.size() != 8)
1575 return emitError(
1576 unknownLoc,
1577 "OpTypeImage with non-eight operands are not supported yet");
1578
1579 Type elementTy = getType(operands[1]);
1580 if (!elementTy)
1581 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1582 << operands[1];
1583
1584 auto dim = spirv::symbolizeDim(operands[2]);
1585 if (!dim)
1586 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1587 << operands[2];
1588
1589 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1590 if (!depthInfo)
1591 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1592 << operands[3];
1593
1594 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1595 if (!arrayedInfo)
1596 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1597 << operands[4];
1598
1599 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1600 if (!samplingInfo)
1601 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1602
1603 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1604 if (!samplerUseInfo)
1605 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1606 << operands[6];
1607
1608 auto format = spirv::symbolizeImageFormat(operands[7]);
1609 if (!format)
1610 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1611 << operands[7];
1612
1613 typeMap[operands[0]] = spirv::ImageType::get(
1614 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1615 samplingInfo.value(), samplerUseInfo.value(), format.value());
1616 return success();
1617}
1618
1619LogicalResult
1621 if (operands.size() != 2)
1622 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1623
1624 Type elementTy = getType(operands[1]);
1625 if (!elementTy)
1626 return emitError(unknownLoc,
1627 "OpTypeSampledImage references undefined <id>: ")
1628 << operands[1];
1629
1630 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1631 return success();
1632}
1633
1634//===----------------------------------------------------------------------===//
1635// Constant
1636//===----------------------------------------------------------------------===//
1637
1639 bool isSpec) {
1640 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1641
1642 if (operands.size() < 2) {
1643 return emitError(unknownLoc)
1644 << opname << " must have type <id> and result <id>";
1645 }
1646 if (operands.size() < 3) {
1647 return emitError(unknownLoc)
1648 << opname << " must have at least 1 more parameter";
1649 }
1650
1651 Type resultType = getType(operands[0]);
1652 if (!resultType) {
1653 return emitError(unknownLoc, "undefined result type from <id> ")
1654 << operands[0];
1655 }
1656
1657 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1658 if (bitwidth == 64) {
1659 if (operands.size() == 4) {
1660 return success();
1661 }
1662 return emitError(unknownLoc)
1663 << opname << " should have 2 parameters for 64-bit values";
1664 }
1665 if (bitwidth <= 32) {
1666 if (operands.size() == 3) {
1667 return success();
1668 }
1669
1670 return emitError(unknownLoc)
1671 << opname
1672 << " should have 1 parameter for values with no more than 32 bits";
1673 }
1674 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1675 << bitwidth;
1676 };
1677
1678 auto resultID = operands[1];
1679
1680 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1681 auto bitwidth = intType.getWidth();
1682 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1683 return failure();
1684 }
1685
1686 APInt value;
1687 if (bitwidth == 64) {
1688 // 64-bit integers are represented with two SPIR-V words. According to
1689 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1690 // literal’s low-order words appear first."
1691 struct DoubleWord {
1692 uint32_t word1;
1693 uint32_t word2;
1694 } words = {operands[2], operands[3]};
1695 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1696 } else if (bitwidth <= 32) {
1697 value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1698 /*implicitTrunc=*/true);
1699 }
1700
1701 auto attr = opBuilder.getIntegerAttr(intType, value);
1702
1703 if (isSpec) {
1704 createSpecConstant(unknownLoc, resultID, attr);
1705 } else {
1706 // For normal constants, we just record the attribute (and its type) for
1707 // later materialization at use sites.
1708 constantMap.try_emplace(resultID, attr, intType);
1709 }
1710
1711 return success();
1712 }
1713
1714 if (auto floatType = dyn_cast<FloatType>(resultType)) {
1715 auto bitwidth = floatType.getWidth();
1716 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1717 return failure();
1718 }
1719
1720 APFloat value(0.f);
1721 if (floatType.isF64()) {
1722 // Double values are represented with two SPIR-V words. According to
1723 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1724 // literal’s low-order words appear first."
1725 struct DoubleWord {
1726 uint32_t word1;
1727 uint32_t word2;
1728 } words = {operands[2], operands[3]};
1729 value = APFloat(llvm::bit_cast<double>(words));
1730 } else if (floatType.isF32()) {
1731 value = APFloat(llvm::bit_cast<float>(operands[2]));
1732 } else if (floatType.isF16()) {
1733 APInt data(16, operands[2]);
1734 value = APFloat(APFloat::IEEEhalf(), data);
1735 } else if (floatType.isBF16()) {
1736 APInt data(16, operands[2]);
1737 value = APFloat(APFloat::BFloat(), data);
1738 } else if (floatType.isF8E4M3FN()) {
1739 APInt data(8, operands[2]);
1740 value = APFloat(APFloat::Float8E4M3FN(), data);
1741 } else if (floatType.isF8E5M2()) {
1742 APInt data(8, operands[2]);
1743 value = APFloat(APFloat::Float8E5M2(), data);
1744 }
1745
1746 auto attr = opBuilder.getFloatAttr(floatType, value);
1747 if (isSpec) {
1748 createSpecConstant(unknownLoc, resultID, attr);
1749 } else {
1750 // For normal constants, we just record the attribute (and its type) for
1751 // later materialization at use sites.
1752 constantMap.try_emplace(resultID, attr, floatType);
1753 }
1754
1755 return success();
1756 }
1757
1758 return emitError(unknownLoc, "OpConstant can only generate values of "
1759 "scalar integer or floating-point type");
1760}
1761
1763 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1764 if (operands.size() != 2) {
1765 return emitError(unknownLoc, "Op")
1766 << (isSpec ? "Spec" : "") << "Constant"
1767 << (isTrue ? "True" : "False")
1768 << " must have type <id> and result <id>";
1769 }
1770
1771 auto attr = opBuilder.getBoolAttr(isTrue);
1772 auto resultID = operands[1];
1773 if (isSpec) {
1774 createSpecConstant(unknownLoc, resultID, attr);
1775 } else {
1776 // For normal constants, we just record the attribute (and its type) for
1777 // later materialization at use sites.
1778 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1779 }
1780
1781 return success();
1782}
1783
1784LogicalResult
1786 if (operands.size() < 2) {
1787 return emitError(unknownLoc,
1788 "OpConstantComposite must have type <id> and result <id>");
1789 }
1790 if (operands.size() < 3) {
1791 return emitError(unknownLoc,
1792 "OpConstantComposite must have at least 1 parameter");
1793 }
1794
1795 Type resultType = getType(operands[0]);
1796 if (!resultType) {
1797 return emitError(unknownLoc, "undefined result type from <id> ")
1798 << operands[0];
1799 }
1800
1802 elements.reserve(operands.size() - 2);
1803 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1804 auto elementInfo = getConstant(operands[i]);
1805 if (!elementInfo) {
1806 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1807 << operands[i] << " must come from a normal constant";
1808 }
1809 elements.push_back(elementInfo->first);
1810 }
1811
1812 auto resultID = operands[1];
1813 if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1814 SmallVector<Attribute> flattenedElems;
1815 for (Attribute element : elements) {
1816 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1817 for (auto value : denseElemAttr.getValues<Attribute>())
1818 flattenedElems.push_back(value);
1819 } else {
1820 flattenedElems.push_back(element);
1821 }
1822 }
1823 auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
1824 constantMap.try_emplace(resultID, attr, tensorType);
1825 } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1826 auto attr = DenseElementsAttr::get(shapedType, elements);
1827 // For normal constants, we just record the attribute (and its type) for
1828 // later materialization at use sites.
1829 constantMap.try_emplace(resultID, attr, shapedType);
1830 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1831 auto attr = opBuilder.getArrayAttr(elements);
1832 constantMap.try_emplace(resultID, attr, resultType);
1833 } else {
1834 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1835 << resultType;
1836 }
1837
1838 return success();
1839}
1840
1842 ArrayRef<uint32_t> operands) {
1843 if (operands.size() != 3) {
1844 return emitError(
1845 unknownLoc,
1846 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1847 << operands.size();
1848 }
1849
1850 Type resultType = getType(operands[0]);
1851 if (!resultType) {
1852 return emitError(unknownLoc, "undefined result type from <id> ")
1853 << operands[0];
1854 }
1855
1856 auto compositeType = dyn_cast<CompositeType>(resultType);
1857 if (!compositeType) {
1858 return emitError(unknownLoc,
1859 "result type from <id> is not a composite type")
1860 << operands[0];
1861 }
1862
1863 uint32_t resultID = operands[1];
1864 uint32_t constantID = operands[2];
1865
1866 std::optional<std::pair<Attribute, Type>> constantInfo =
1867 getConstant(constantID);
1868 if (constantInfo.has_value()) {
1869 constantCompositeReplicateMap.try_emplace(
1870 resultID, constantInfo.value().first, resultType);
1871 return success();
1872 }
1873
1874 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1876 if (replicatedConstantCompositeInfo.has_value()) {
1877 constantCompositeReplicateMap.try_emplace(
1878 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1879 return success();
1880 }
1881
1882 return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
1883 << constantID
1884 << " must come from a normal constant or a "
1885 "OpConstantCompositeReplicateEXT";
1886}
1887
1888LogicalResult
1890 if (operands.size() < 2) {
1891 return emitError(
1892 unknownLoc,
1893 "OpSpecConstantComposite must have type <id> and result <id>");
1894 }
1895 if (operands.size() < 3) {
1896 return emitError(unknownLoc,
1897 "OpSpecConstantComposite must have at least 1 parameter");
1898 }
1899
1900 Type resultType = getType(operands[0]);
1901 if (!resultType) {
1902 return emitError(unknownLoc, "undefined result type from <id> ")
1903 << operands[0];
1904 }
1905
1906 auto resultID = operands[1];
1907 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1908
1910 elements.reserve(operands.size() - 2);
1911 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1912 auto elementInfo = getSpecConstant(operands[i]);
1913 elements.push_back(SymbolRefAttr::get(elementInfo));
1914 }
1915
1916 auto op = spirv::SpecConstantCompositeOp::create(
1917 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1918 opBuilder.getArrayAttr(elements));
1919 specConstCompositeMap[resultID] = op;
1920
1921 return success();
1922}
1923
1925 ArrayRef<uint32_t> operands) {
// Deserializes an OpSpecConstantCompositeReplicateEXT instruction.
//
// Requires exactly three operands: result type <id>, result <id>, and the
// constituent spec constant <id>. The result type must be a CompositeType.
// A spirv::EXTSpecConstantCompositeReplicateOp that references the
// constituent by symbol is created and recorded in
// specConstCompositeReplicateMap under the result <id>.
1926 if (operands.size() != 3) {
1927 return emitError(unknownLoc, "OpSpecConstantCompositeReplicateEXT expects "
1928 "3 operands but found ")
1929 << operands.size();
1930 }
1931
1932 Type resultType = getType(operands[0]);
1933 if (!resultType) {
1934 return emitError(unknownLoc, "undefined result type from <id> ")
1935 << operands[0];
1936 }
1937
// compositeType is only used for this check.
1938 auto compositeType = dyn_cast<CompositeType>(resultType);
1939 if (!compositeType) {
// NOTE(review): the <id> is streamed directly after "composite type" with no
// separator, so the message reads e.g. "...is not a composite type7".
1940 return emitError(unknownLoc,
1941 "result type from <id> is not a composite type")
1942 << operands[0];
1943 }
1944
1945 uint32_t resultID = operands[1];
1946
1947 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
// NOTE(review): the constituent op is not null-checked before
// SymbolRefAttr::get below. Confirm what getSpecConstant returns for an
// unknown <id>.
1948 spirv::SpecConstantOp constituentSpecConstantOp =
1949 getSpecConstant(operands[2]);
1950 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1951 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1952 SymbolRefAttr::get(constituentSpecConstantOp));
1953
1954 specConstCompositeReplicateMap[resultID] = op;
1955
1956 return success();
1957}
1958
// Records a spec-constant operation for deferred materialization.
//
// Operand layout (as read below): operands[0] = result type <id>,
// operands[1] = result <id>, operands[2] = opcode of the enclosed
// instruction, operands[3..] = the enclosed instruction's own operands.
// No IR is emitted here. The opcode, result type <id> and remaining operands
// are stored in specConstOperationMap, keyed by the result <id>.
// Presumably they are consumed later by materializeSpecConstantOperation;
// confirm at the use site.
//
// Fails if fewer than three operands are present, if the result type does
// not resolve, or if the result <id> already has an entry.
// NOTE(review): the diagnostic says "OpConstantOperation", while the SPIR-V
// instruction is spelled OpSpecConstantOp. Confirm the intended wording.
1959LogicalResult
1961 if (operands.size() < 3)
1962 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1963 "result <id>, and operand opcode");
1964
1965 uint32_t resultTypeID = operands[0];
1966
1967 if (!getType(resultTypeID))
1968 return emitError(unknownLoc, "undefined result type from <id> ")
1969 << resultTypeID;
1970
1971 uint32_t resultID = operands[1];
// NOTE(review): the opcode word is cast without validation. An unknown
// opcode is only detected when the instruction is later processed.
1972 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1973 auto emplaceResult = specConstOperationMap.try_emplace(
1974 resultID,
1976 enclosedOpcode, resultTypeID,
1977 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1978
1979 if (!emplaceResult.second)
1980 return emitError(unknownLoc, "value with <id>: ")
1981 << resultID << " is probably defined before.";
1982
1983 return success();
1984}
1985
1987 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1988 ArrayRef<uint32_t> enclosedOpOperands) {
// Materializes a deferred spec-constant operation as a
// spirv::SpecConstantOperationOp. Its region holds the enclosed instruction
// followed by a spirv::YieldOp of that instruction's first result.
//
// Returns the new op's result, or a null Value if the enclosed instruction
// fails to deserialize.
//
// NOTE(review): none of the following is checked here:
//   - curBlock is non-null;
//   - processInstruction leaves the enclosed op as the last op of curBlock
//     (see splitBlock(&curBlock->back()));
//   - that op has at least one result (see getResult(0)).
1989
1990 Type resultType = getType(resultTypeID);
1991
1992 // Instructions wrapped by OpSpecConstantOp need an ID for their
1993 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1994 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1995 // ID in that map is assigned to the result of the enclosed instruction. Note
1996 // that there is no need to update this fake ID since we only need to
1997 // reference the created Value for the enclosed op from the spv::YieldOp
1998 // created later in this method (both of which are the only values in their
1999 // region: the SpecConstantOperation's region). If we encounter another
2000 // SpecConstantOperation in the module, we simply re-use the fake ID since the
2001 // previous Value assigned to it isn't visible in the current scope anyway.
2002 DenseMap<uint32_t, Value> newValueMap;
// valueMap is replaced by the empty map for the rest of this function. The
// previous contents are restored when valueMapGuard is destroyed.
2003 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
// Equals 0xFFFFFFFD.
2004 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
2005
// Build the word list the enclosed instruction's handler expects:
// result type <id>, result <id> (the fake one), then its operands.
2006 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
2007 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2008 enclosedOpResultTypeAndOperands.push_back(fakeID);
2009 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2010 enclosedOpOperands.end());
2011
2012 // Process enclosed instruction before creating the enclosing
2013 // specConstantOperation (and its region). This way, references to constants,
2014 // global variables, and spec constants will be materialized outside the new
2015 // op's region. For more info, see Deserializer::getValue's implementation.
2016 if (failed(
2017 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
2018 return Value();
2019
2020 // Since the enclosed op is emitted in the current block, split it in a
2021 // separate new block.
2022 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
2023
2024 auto loc = createFileLineColLoc(opBuilder);
2025 auto specConstOperationOp =
2026 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2027
2028 Region &body = specConstOperationOp.getBody();
2029 // Move the new block into SpecConstantOperation's body.
2030 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
2031 Region::iterator(enclosedBlock));
2032 Block &block = body.back();
2033
2034 // RAII guard to reset the insertion point to the module's region after
2035 // deserializing the body of the specConstantOperation.
2036 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
2037 opBuilder.setInsertionPointToEnd(&block);
2038
2039 spirv::YieldOp::create(opBuilder, loc, block.front().getResult(0));
2040 return specConstOperationOp.getResult();
2041}
2042
// Deserializes an OpConstantNull instruction.
//
// Expects exactly two operands: result type <id> and result <id>. The zero
// attribute is chosen by result type:
//   - integer, float and vector types use opBuilder.getZeroAttr(type);
//   - a TensorArmType uses a DenseElementsAttr built from the zero attribute
//     of its element type;
//   - any other type is reported as unsupported.
// The attribute and type are recorded in constantMap rather than
// materialized here. try_emplace keeps an existing entry if the result <id>
// repeats.
2043LogicalResult
2045 if (operands.size() != 2) {
2046 return emitError(unknownLoc,
2047 "OpConstantNull must only have type <id> and result <id>");
2048 }
2049
2050 Type resultType = getType(operands[0]);
2051 if (!resultType) {
2052 return emitError(unknownLoc, "undefined result type from <id> ")
2053 << operands[0];
2054 }
2055
2056 auto resultID = operands[1];
2057 Attribute attr;
2058 if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
2059 attr = opBuilder.getZeroAttr(resultType);
2060 } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
// If the element type has no zero attribute, attr stays null and the
// "unsupported" error below is emitted.
2061 if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2062 attr = DenseElementsAttr::get(tensorType, element);
2063 }
2064
2065 if (attr) {
2066 // For normal constants, we just record the attribute (and its type) for
2067 // later materialization at use sites.
2068 constantMap.try_emplace(resultID, attr, resultType);
2069 return success();
2070 }
2071
2072 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
2073 << resultType;
2074}
2075
// Deserializes an OpGraphConstantARM instruction.
//
// Operand layout (as read below): operands[0] = result type <id>, which must
// be a spirv::TensorArmType; operands[1] = result <id>; operands[2] = the
// graph constant id, stored as a 32-bit IntegerAttr. The (type, attr) pair
// is recorded in graphConstantMap. try_emplace keeps the first entry if the
// result <id> repeats.
//
// NOTE(review): the check below requires three operands, but its diagnostic
// says "at least 2 operands".
2076LogicalResult
2078 if (operands.size() < 3) {
2079 return emitError(unknownLoc)
2080 << "OpGraphConstantARM must have at least 2 operands";
2081 }
2082
2083 Type resultType = getType(operands[0]);
2084 if (!resultType) {
2085 return emitError(unknownLoc, "undefined result type from <id> ")
2086 << operands[0];
2087 }
2088
2089 uint32_t resultID = operands[1];
2090
// NOTE(review): dyn_cast is used only as a boolean test. isa<> would state
// the intent directly.
2091 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2092 return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
2093 }
2094
// The raw 32-bit word is wrapped in a signed 32-bit APInt.
// NOTE(review): snake_case name; the rest of the file uses lowerCamelCase.
2095 APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
2096 Type i32Ty = opBuilder.getIntegerType(32);
2097 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2098 graphConstantMap.try_emplace(
2099 resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2100
2101 return success();
2102}
2103
2104//===----------------------------------------------------------------------===//
2105// Control flow
2106//===----------------------------------------------------------------------===//
2107
// Returns the block already associated with label <id> `id`, if any.
// NOTE(review): the debug text says "exiting block"; "existing" is meant.
2109 if (auto *block = getBlock(id)) {
2110 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
2111 << " @ " << block << "\n");
2112 return block;
2113 }
2114
2115 // We don't know where this block will be placed finally (in a
2116 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
2117 // function for now and sort out the proper place later.
// NOTE(review): curFunction is dereferenced without a check. createGraphBlock
// calls this after checking only curGraph, so confirm curFunction is always
// set on that path.
2118 auto *block = curFunction->addBlock();
2119 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
2120 << " @ " << block << "\n");
// Register the new block so later references to `id` find it.
2121 return blockMap[id] = block;
2122}
2123
// OpBranch: emits an unconditional spirv::BranchOp from the current block to
// the block for label operands[0]. That block is created if it has not been
// seen yet.
2125 if (!curBlock) {
2126 return emitError(unknownLoc, "OpBranch must appear inside a block");
2127 }
2128
2129 if (operands.size() != 1) {
2130 return emitError(unknownLoc, "OpBranch must take exactly one target label");
2131 }
2132
2133 auto *target = getOrCreateBlock(operands[0]);
2134 auto loc = createFileLineColLoc(opBuilder);
2135 // The preceding instruction for the OpBranch instruction could be an
2136 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
2137 // the same OpLine information.
2138 spirv::BranchOp::create(opBuilder, loc, target);
2139
2141 return success();
2142}
2143
// OpBranchConditional: emits a spirv::BranchConditionalOp from the current
// block.
//
// Operand layout: operands[0] = condition <id>, operands[1] = true label
// <id>, operands[2] = false label <id>. When five operands are present,
// operands[3] and operands[4] are passed through as branch weights. Neither
// target receives block arguments here.
2144LogicalResult
2146 if (!curBlock) {
2147 return emitError(unknownLoc,
2148 "OpBranchConditional must appear inside a block");
2149 }
2150
2151 if (operands.size() != 3 && operands.size() != 5) {
2152 return emitError(unknownLoc,
2153 "OpBranchConditional must have condition, true label, "
2154 "false label, and optionally two branch weights");
2155 }
2156
// NOTE(review): the condition Value is not checked for null before being
// used as an operand.
2157 auto condition = getValue(operands[0]);
2158 auto *trueBlock = getOrCreateBlock(operands[1]);
2159 auto *falseBlock = getOrCreateBlock(operands[2]);
2160
2161 std::optional<std::pair<uint32_t, uint32_t>> weights;
2162 if (operands.size() == 5) {
2163 weights = std::make_pair(operands[3], operands[4]);
2164 }
2165 // The preceding instruction for the OpBranchConditional instruction could be
2166 // an OpSelectionMerge instruction, in this case they will have the same
2167 // OpLine information.
2168 auto loc = createFileLineColLoc(opBuilder);
2169 spirv::BranchConditionalOp::create(
2170 opBuilder, loc, condition, trueBlock,
2171 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
2172 /*falseArguments=*/ArrayRef<Value>(), weights);
2173
2175 return success();
2176}
2177
// OpLabel: starts populating the block for label operands[0]. Inserts at
// the start of that block and makes it curBlock. Requires an enclosing
// function and exactly one operand.
2179 if (!curFunction) {
2180 return emitError(unknownLoc, "OpLabel must appear inside a function");
2181 }
2182
2183 if (operands.size() != 1) {
2184 return emitError(unknownLoc, "OpLabel should only have result <id>");
2185 }
2186
2187 auto labelID = operands[0];
2188 // We may have forward declared this block.
2189 auto *block = getOrCreateBlock(labelID);
2190 LLVM_DEBUG(logger.startLine()
2191 << "[block] populating block " << block << "\n");
2192 // If we have seen this block, make sure it was just a forward declaration.
// NOTE(review): this is an assert only. In NDEBUG builds a repeated label
// is not diagnosed.
2193 assert(block->empty() && "re-deserialize the same block!");
2194
2195 opBuilder.setInsertionPointToStart(block);
2196 blockMap[labelID] = curBlock = block;
2197
2198 return success();
2199}
2200
2201LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
2202 if (!curGraph) {
2203 return emitError(unknownLoc, "a graph block must appear inside a graph");
2204 }
2205
2206 // We may have forward declared this block.
2207 Block *block = getOrCreateBlock(graphID);
2208 LLVM_DEBUG(logger.startLine()
2209 << "[block] populating block " << block << "\n");
2210 // If we have seen this block, make sure it was just a forward declaration.
2211 assert(block->empty() && "re-deserialize the same block!");
2212
2213 opBuilder.setInsertionPointToStart(block);
2214 blockMap[graphID] = curBlock = block;
2215
2216 return success();
2217}
2218
// OpSelectionMerge: no IR is emitted. The current block's merge info is
// recorded in blockMergeInfo: the source location, the selection control
// word (operands[1]) and the merge block (operands[0], created if unseen).
// Presumably this is consumed later during control-flow structurization;
// confirm at the use site.
// Fails if the current block already has merge info, which includes an
// earlier OpLoopMerge because both share the curBlock key.
2219LogicalResult
2221 if (!curBlock) {
2222 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
2223 }
2224
2225 if (operands.size() < 2) {
2226 return emitError(
2227 unknownLoc,
2228 "OpSelectionMerge must specify merge target and selection control");
2229 }
2230
2231 auto *mergeBlock = getOrCreateBlock(operands[0]);
2232 auto loc = createFileLineColLoc(opBuilder);
2233 auto selectionControl = operands[1];
2234
2235 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2236 .second) {
2237 return emitError(
2238 unknownLoc,
2239 "a block cannot have more than one OpSelectionMerge instruction");
2240 }
2241
2242 return success();
2243}
2244
// OpLoopMerge: no IR is emitted. The current block's loop merge info is
// recorded in blockMergeInfo: the source location, the loop control word
// (operands[2]), the merge block (operands[0]) and the continue block
// (operands[1]). Both blocks are created if unseen.
// NOTE(review): any operands after index 2 are ignored here.
// Fails if the current block already has merge info.
2245LogicalResult
2247 if (!curBlock) {
2248 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
2249 }
2250
2251 if (operands.size() < 3) {
2252 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
2253 "continue target and loop control");
2254 }
2255
2256 auto *mergeBlock = getOrCreateBlock(operands[0]);
2257 auto *continueBlock = getOrCreateBlock(operands[1]);
2258 auto loc = createFileLineColLoc(opBuilder);
2259 uint32_t loopControl = operands[2];
2260
2261 if (!blockMergeInfo
2262 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2263 .second) {
2264 return emitError(
2265 unknownLoc,
2266 "a block cannot have more than one OpLoopMerge instruction");
2267 }
2268
2269 return success();
2270}
2271
// OpPhi: operands[0] = result type <id>, operands[1] = result <id>, then
// (value <id>, parent label <id>) pairs.
// The phi becomes a new argument of the current block, mapped to the result
// <id> in valueMap. Each incoming value <id> is queued in blockPhiInfo under
// its (predecessor, current block) edge so the branch operands can be fixed
// up later.
// NOTE(review): only size >= 4 is checked. With an odd operand count (e.g. 5)
// the loop below reads operands[i + 1] one past the end. Consider rejecting
// odd counts.
2273 if (!curBlock) {
2274 return emitError(unknownLoc, "OpPhi must appear in a block");
2275 }
2276
2277 if (operands.size() < 4) {
2278 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
2279 "and variable-parent pairs");
2280 }
2281
2282 // Create a block argument for this OpPhi instruction.
// NOTE(review): getType() is not checked for null before addArgument.
2283 Type blockArgType = getType(operands[0]);
2284 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2285 valueMap[operands[1]] = blockArg;
2286 LLVM_DEBUG(logger.startLine()
2287 << "[phi] created block argument " << blockArg
2288 << " id = " << operands[1] << " of type " << blockArgType << "\n");
2289
2290 // For each (value, predecessor) pair, insert the value to the predecessor's
2291 // blockPhiInfo entry so later we can fix the block argument there.
2292 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2293 uint32_t value = operands[i];
2294 Block *predecessor = getOrCreateBlock(operands[i + 1]);
2295 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2296 blockPhiInfo[predecessorTargetPair].push_back(value);
2297 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
2298 << " with arg id = " << value << "\n");
2299 }
2300
2301 return success();
2302}
2303
// OpSwitch: operands[0] = selector <id>, operands[1] = default label <id>,
// then (literal, label <id>) pairs. Emits a spirv::SwitchOp from the current
// block. No target receives block operands.
// NOTE(review): each case literal is read as one 32-bit word and converted
// to int32_t. If selectors wider than 32 bits can reach this point, their
// multi-word literals would be misparsed; confirm they are rejected
// elsewhere.
// NOTE(review): the even-count diagnostic contains a typo ("must at have").
2305 if (!curBlock)
2306 return emitError(unknownLoc, "OpSwitch must appear in a block");
2307
2308 if (operands.size() < 2)
2309 return emitError(unknownLoc, "OpSwitch must at least specify selector and "
2310 "a default target");
2311
2312 if (operands.size() % 2)
2313 return emitError(unknownLoc,
2314 "OpSwitch must at have an even number of operands: "
2315 "selector, default target and any number of literal and "
2316 "label <id> pairs");
2317
// NOTE(review): the selector Value is not checked for null.
2318 Value selector = getValue(operands[0]);
2319 Block *defaultBlock = getOrCreateBlock(operands[1]);
2320 Location loc = createFileLineColLoc(opBuilder);
2321
2322 SmallVector<int32_t> literals;
2323 SmallVector<Block *> blocks;
2324 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2325 literals.push_back(operands[i]);
2326 blocks.push_back(getOrCreateBlock(operands[i + 1]));
2327 }
2328
// One empty operand list per case target.
2329 SmallVector<ValueRange> targetOperands(blocks.size(), {});
2330 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2331 ArrayRef<Value>(), literals, blocks, targetOperands);
2332
2333 return success();
2334}
2335
2336namespace {
2337/// A class for putting all blocks in a structured selection/loop in a
2338/// spirv.mlir.selection/spirv.mlir.loop op.
2339class ControlFlowStructurizer {
2340public:
2341#ifndef NDEBUG
2342 ControlFlowStructurizer(Location loc, uint32_t control,
2343 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2344 Block *merge, Block *cont,
2345 llvm::ScopedPrinter &logger)
2346 : location(loc), control(control), blockMergeInfo(mergeInfo),
2347 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2348 logger(logger) {}
2349#else
2350 ControlFlowStructurizer(Location loc, uint32_t control,
2351 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2352 Block *merge, Block *cont)
2353 : location(loc), control(control), blockMergeInfo(mergeInfo),
2354 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2355#endif
2356
2357 /// Structurizes the loop at the given `headerBlock`.
2358 ///
2359 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
2360 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
2361 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
2362 /// method will also update `mergeInfo` by remapping all blocks inside to the
2363 /// newly cloned ones inside structured control flow op's regions.
2364 LogicalResult structurize();
2365
2366private:
2367 /// Creates a new spirv.mlir.selection op at the beginning of the
2368 /// `mergeBlock`.
2369 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2370
2371 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
2372 spirv::LoopOp createLoopOp(uint32_t loopControl);
2373
2374 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
2375 void collectBlocksInConstruct();
2376
2377 Location location;
2378 uint32_t control;
2379
2380 spirv::BlockMergeInfoMap &blockMergeInfo;
2381
2382 Block *headerBlock;
2383 Block *mergeBlock;
2384 Block *continueBlock; // nullptr for spirv.mlir.selection
2385
2386 SetVector<Block *> constructBlocks;
2387
2388#ifndef NDEBUG
2389 /// A logger used to emit information during the deserialzation process.
2390 llvm::ScopedPrinter &logger;
2391#endif
2392};
2393} // namespace
2394
2395spirv::SelectionOp
2396ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2397 // Create a builder and set the insertion point to the beginning of the
2398 // merge block so that the newly created SelectionOp will be inserted there.
2399 OpBuilder builder(&mergeBlock->front());
2400
2401 auto control = static_cast<spirv::SelectionControl>(selectionControl);
2402 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2403 selectionOp.addMergeBlock(builder);
2404
2405 return selectionOp;
2406}
2407
2408spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2409 // Create a builder and set the insertion point to the beginning of the
2410 // merge block so that the newly created LoopOp will be inserted there.
2411 OpBuilder builder(&mergeBlock->front());
2412
2413 auto control = static_cast<spirv::LoopControl>(loopControl);
2414 auto loopOp = spirv::LoopOp::create(builder, location, control);
2415 loopOp.addEntryAndMergeBlock(builder);
2416
2417 return loopOp;
2418}
2419
2420void ControlFlowStructurizer::collectBlocksInConstruct() {
2421 assert(constructBlocks.empty() && "expected empty constructBlocks");
2422
2423 // Put the header block in the work list first.
2424 constructBlocks.insert(headerBlock);
2425
2426 // For each item in the work list, add its successors excluding the merge
2427 // block.
2428 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
2429 for (auto *successor : constructBlocks[i]->getSuccessors())
2430 if (successor != mergeBlock)
2431 constructBlocks.insert(successor);
2432 }
2433}
2434
2435LogicalResult ControlFlowStructurizer::structurize() {
2436 Operation *op = nullptr;
2437 bool isLoop = continueBlock != nullptr;
2438 if (isLoop) {
2439 if (auto loopOp = createLoopOp(control))
2440 op = loopOp.getOperation();
2441 } else {
2442 if (auto selectionOp = createSelectionOp(control))
2443 op = selectionOp.getOperation();
2444 }
2445 if (!op)
2446 return failure();
2447 Region &body = op->getRegion(0);
2448
2449 IRMapping mapper;
2450 // All references to the old merge block should be directed to the
2451 // selection/loop merge block in the SelectionOp/LoopOp's region.
2452 mapper.map(mergeBlock, &body.back());
2453
2454 collectBlocksInConstruct();
2455
2456 // We've identified all blocks belonging to the selection/loop's region. Now
2457 // need to "move" them into the selection/loop. Instead of really moving the
2458 // blocks, in the following we copy them and remap all values and branches.
2459 // This is because:
2460 // * Inserting a block into a region requires the block not in any region
2461 // before. But selections/loops can nest so we can create selection/loop ops
2462 // in a nested manner, which means some blocks may already be in a
2463 // selection/loop region when to be moved again.
2464 // * It's much trickier to fix up the branches into and out of the loop's
2465 // region: we need to treat not-moved blocks and moved blocks differently:
2466 // Not-moved blocks jumping to the loop header block need to jump to the
2467 // merge point containing the new loop op but not the loop continue block's
2468 // back edge. Moved blocks jumping out of the loop need to jump to the
2469 // merge block inside the loop region but not other not-moved blocks.
2470 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2471 // logic.
2472
2473 // Create a corresponding block in the SelectionOp/LoopOp's region for each
2474 // block in this loop construct.
2475 OpBuilder builder(body);
2476 for (auto *block : constructBlocks) {
2477 // Create a block and insert it before the selection/loop merge block in the
2478 // SelectionOp/LoopOp's region.
2479 auto *newBlock = builder.createBlock(&body.back());
2480 mapper.map(block, newBlock);
2481 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2482 << " from block " << block << "\n");
2483 if (!isFnEntryBlock(block)) {
2484 for (BlockArgument blockArg : block->getArguments()) {
2485 auto newArg =
2486 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2487 mapper.map(blockArg, newArg);
2488 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2489 << blockArg << " to " << newArg << "\n");
2490 }
2491 } else {
2492 LLVM_DEBUG(logger.startLine()
2493 << "[cf] block " << block << " is a function entry block\n");
2494 }
2495
2496 for (auto &op : *block)
2497 newBlock->push_back(op.clone(mapper));
2498 }
2499
2500 // Go through all ops and remap the operands.
2501 auto remapOperands = [&](Operation *op) {
2502 for (auto &operand : op->getOpOperands())
2503 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2504 operand.set(mappedOp);
2505 for (auto &succOp : op->getBlockOperands())
2506 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2507 succOp.set(mappedOp);
2508 };
2509 for (auto &block : body)
2510 block.walk(remapOperands);
2511
2512 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2513 // the selection/loop construct into its region. Next we need to fix the
2514 // connections between this new SelectionOp/LoopOp with existing blocks.
2515
2516 // All existing incoming branches should go to the merge block, where the
2517 // SelectionOp/LoopOp resides right now.
2518 headerBlock->replaceAllUsesWith(mergeBlock);
2519
2520 LLVM_DEBUG({
2521 logger.startLine() << "[cf] after cloning and fixing references:\n";
2522 headerBlock->getParentOp()->print(logger.getOStream());
2523 logger.startLine() << "\n";
2524 });
2525
2526 if (isLoop) {
2527 if (!mergeBlock->args_empty()) {
2528 return mergeBlock->getParentOp()->emitError(
2529 "OpPhi in loop merge block unsupported");
2530 }
2531
2532 // The loop header block may have block arguments. Since now we place the
2533 // loop op inside the old merge block, we need to make sure the old merge
2534 // block has the same block argument list.
2535 for (BlockArgument blockArg : headerBlock->getArguments())
2536 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2537
2538 // If the loop header block has block arguments, make sure the spirv.Branch
2539 // op matches.
2540 SmallVector<Value, 4> blockArgs;
2541 if (!headerBlock->args_empty())
2542 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2543
2544 // The loop entry block should have a unconditional branch jumping to the
2545 // loop header block.
2546 builder.setInsertionPointToEnd(&body.front());
2547 spirv::BranchOp::create(builder, location, mapper.lookupOrNull(headerBlock),
2548 ArrayRef<Value>(blockArgs));
2549 }
2550
2551 // Values defined inside the selection region that need to be yielded outside
2552 // the region.
2553 SmallVector<Value> valuesToYield;
2554 // Outside uses of values that were sunk into the selection region. Those uses
2555 // will be replaced with values returned by the SelectionOp.
2556 SmallVector<Value> outsideUses;
2557
2558 // Move block arguments of the original block (`mergeBlock`) into the merge
2559 // block inside the selection (`body.back()`). Values produced by block
2560 // arguments will be yielded by the selection region. We do not update uses or
2561 // erase original block arguments yet. It will be done later in the code.
2562 //
2563 // Code below is not executed for loops as it would interfere with the logic
2564 // above. Currently block arguments in the merge block are not supported, but
2565 // instead, the code above copies those arguments from the header block into
2566 // the merge block. As such, running the code would yield those copied
2567 // arguments that is most likely not a desired behaviour. This may need to be
2568 // revisited in the future.
2569 if (!isLoop)
2570 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2571 // Create new block arguments in the last block ("merge block") of the
2572 // selection region. We create one argument for each argument in
2573 // `mergeBlock`. This new value will need to be yielded, and the original
2574 // value replaced, so add them to appropriate vectors.
2575 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2576 valuesToYield.push_back(body.back().getArguments().back());
2577 outsideUses.push_back(blockArg);
2578 }
2579
2580 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2581 // cleaned up.
2582 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2583 // First we need to drop all operands' references inside all blocks. This is
2584 // needed because we can have blocks referencing SSA values from one another.
2585 for (auto *block : constructBlocks)
2586 block->dropAllReferences();
2587
2588 // All internal uses should be removed from original blocks by now, so
2589 // whatever is left is an outside use and will need to be yielded from
2590 // the newly created selection / loop region.
2591 for (Block *block : constructBlocks) {
2592 for (Operation &op : *block) {
2593 if (!op.use_empty())
2594 for (Value result : op.getResults()) {
2595 valuesToYield.push_back(mapper.lookupOrNull(result));
2596 outsideUses.push_back(result);
2597 }
2598 }
2599 for (BlockArgument &arg : block->getArguments()) {
2600 if (!arg.use_empty()) {
2601 valuesToYield.push_back(mapper.lookupOrNull(arg));
2602 outsideUses.push_back(arg);
2603 }
2604 }
2605 }
2606
2607 assert(valuesToYield.size() == outsideUses.size());
2608
2609 // If we need to yield any values from the selection / loop region we will
2610 // take care of it here.
2611 if (!valuesToYield.empty()) {
2612 LLVM_DEBUG(logger.startLine()
2613 << "[cf] yielding values from the selection / loop region\n");
2614
2615 // Update `mlir.merge` with values to be yield.
2616 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2617 Operation *merge = llvm::getSingleElement(mergeOps);
2618 assert(merge);
2619 merge->setOperands(valuesToYield);
2620
2621 // MLIR does not allow changing the number of results of an operation, so
2622 // we create a new SelectionOp / LoopOp with required list of results and
2623 // move the region from the initial SelectionOp / LoopOp. The initial
2624 // operation is then removed. Since we move the region to the new op all
2625 // links between blocks and remapping we have previously done should be
2626 // preserved.
2627 builder.setInsertionPoint(&mergeBlock->front());
2628
2629 Operation *newOp = nullptr;
2630
2631 if (isLoop)
2632 newOp = spirv::LoopOp::create(builder, location,
2633 TypeRange(ValueRange(outsideUses)),
2634 static_cast<spirv::LoopControl>(control));
2635 else
2636 newOp = spirv::SelectionOp::create(
2637 builder, location, TypeRange(ValueRange(outsideUses)),
2638 static_cast<spirv::SelectionControl>(control));
2639
2640 newOp->getRegion(0).takeBody(body);
2641
2642 // Remove initial op and swap the pointer to the newly created one.
2643 op->erase();
2644 op = newOp;
2645
2646 // Update all outside uses to use results of the SelectionOp / LoopOp and
2647 // remove block arguments from the original merge block.
2648 for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2649 outsideUses[i].replaceAllUsesWith(op->getResult(i));
2650
2651 // We do not support block arguments in loop merge block. Also running this
2652 // function with loop would break some of the loop specific code above
2653 // dealing with block arguments.
2654 if (!isLoop)
2655 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2656 }
2657
2658 // Check that whether some op in the to-be-erased blocks still has uses. Those
2659 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2660 // region. We cannot handle such cases given that once a value is sinked into
2661 // the SelectionOp/LoopOp's region, there is no escape for it.
2662 for (auto *block : constructBlocks) {
2663 if (!block->use_empty())
2664 return emitError(block->getParent()->getLoc(),
2665 "failed control flow structurization: "
2666 "block has uses outside of the "
2667 "enclosing selection/loop construct");
2668 for (Operation &op : *block)
2669 if (!op.use_empty())
2670 return op.emitOpError("failed control flow structurization: value has "
2671 "uses outside of the "
2672 "enclosing selection/loop construct");
2673 for (BlockArgument &arg : block->getArguments())
2674 if (!arg.use_empty())
2675 return emitError(arg.getLoc(), "failed control flow structurization: "
2676 "block argument has uses outside of the "
2677 "enclosing selection/loop construct");
2678 }
2679
2680 // Then erase all old blocks.
2681 for (auto *block : constructBlocks) {
2682 // We've cloned all blocks belonging to this construct into the structured
2683 // control flow op's region. Among these blocks, some may compose another
2684 // selection/loop. If so, they will be recorded within blockMergeInfo.
2685 // We need to update the pointers there to the newly remapped ones so we can
2686 // continue structurizing them later.
2687 //
2688 // We need to walk each block as constructBlocks do not include blocks
2689 // internal to ops already structured within those blocks. It is not
2690 // fully clear to me why the mergeInfo of blocks (yet to be structured)
2691 // inside already structured selections/loops get invalidated and needs
2692 // updating, however the following example code can cause a crash (depending
2693 // on the structuring order), when the most inner selection is being
2694 // structured after the outer selection and loop have been already
2695 // structured:
2696 //
2697 // spirv.mlir.for {
2698 // // ...
2699 // spirv.mlir.selection {
2700 // // ..
2701 // // A selection region that hasn't been yet structured!
2702 // // ..
2703 // }
2704 // // ...
2705 // }
2706 //
2707 // If the loop gets structured after the outer selection, but before the
2708 // inner selection. Moving the already structured selection inside the loop
2709 // will invalidate the mergeInfo of the region that is not yet structured.
2710 // Just going over constructBlocks will not check and updated header blocks
2711 // inside the already structured selection region. Walking block fixes that.
2712 //
2713 // TODO: If structuring was done in a fixed order starting with inner
2714 // most constructs this most likely not be an issue and the whole code
2715 // section could be removed. However, with the current non-deterministic
2716 // order this is not possible.
2717 //
2718 // TODO: The asserts in the following assumes input SPIR-V blob forms
2719 // correctly nested selection/loop constructs. We should relax this and
2720 // support error cases better.
2721 auto updateMergeInfo = [&](Block *block) -> WalkResult {
2722 auto it = blockMergeInfo.find(block);
2723 if (it != blockMergeInfo.end()) {
2724 // Use the original location for nested selection/loop ops.
2725 Location loc = it->second.loc;
2726
2727 Block *newHeader = mapper.lookupOrNull(block);
2728 if (!newHeader)
2729 return emitError(loc, "failed control flow structurization: nested "
2730 "loop header block should be remapped!");
2731
2732 Block *newContinue = it->second.continueBlock;
2733 if (newContinue) {
2734 newContinue = mapper.lookupOrNull(newContinue);
2735 if (!newContinue)
2736 return emitError(loc, "failed control flow structurization: nested "
2737 "loop continue block should be remapped!");
2738 }
2739
2740 Block *newMerge = it->second.mergeBlock;
2741 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2742 newMerge = mappedTo;
2743
2744 // The iterator should be erased before adding a new entry into
2745 // blockMergeInfo to avoid iterator invalidation.
2746 blockMergeInfo.erase(it);
2747 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2748 newContinue);
2749 }
2750
2751 return WalkResult::advance();
2752 };
2753
2754 if (block->walk(updateMergeInfo).wasInterrupted())
2755 return failure();
2756
2757 // The structured selection/loop's entry block does not have arguments.
2758 // If the function's header block is also part of the structured control
2759 // flow, we cannot just simply erase it because it may contain arguments
2760 // matching the function signature and used by the cloned blocks.
2761 if (isFnEntryBlock(block)) {
2762 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2763 << " to only contain a spirv.Branch op\n");
2764 // Still keep the function entry block for the potential block arguments,
2765 // but replace all ops inside with a branch to the merge block.
2766 block->clear();
2767 builder.setInsertionPointToEnd(block);
2768 spirv::BranchOp::create(builder, location, mergeBlock);
2769 } else {
2770 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2771 block->erase();
2772 }
2773 }
2774
2775 LLVM_DEBUG(logger.startLine()
2776 << "[cf] after structurizing construct with header block "
2777 << headerBlock << ":\n"
2778 << *op << "\n");
2779
2780 return success();
2781}
2782
2784 LLVM_DEBUG({
2785 logger.startLine()
2786 << "//----- [phi] start wiring up block arguments -----//\n";
2787 logger.indent();
2788 });
2789
2790 OpBuilder::InsertionGuard guard(opBuilder);
2791
2792 for (const auto &info : blockPhiInfo) {
2793 Block *block = info.first.first;
2794 Block *target = info.first.second;
2795 const BlockPhiInfo &phiInfo = info.second;
2796 LLVM_DEBUG({
2797 logger.startLine() << "[phi] block " << block << "\n";
2798 logger.startLine() << "[phi] before creating block argument:\n";
2799 block->getParentOp()->print(logger.getOStream());
2800 logger.startLine() << "\n";
2801 });
2802
2803 // Set insertion point to before this block's terminator early because we
2804 // may materialize ops via getValue() call.
2805 auto *op = block->getTerminator();
2806 opBuilder.setInsertionPoint(op);
2807
2808 SmallVector<Value, 4> blockArgs;
2809 blockArgs.reserve(phiInfo.size());
2810 for (uint32_t valueId : phiInfo) {
2811 if (Value value = getValue(valueId)) {
2812 blockArgs.push_back(value);
2813 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2814 << " id = " << valueId << "\n");
2815 } else {
2816 return emitError(unknownLoc, "OpPhi references undefined value!");
2817 }
2818 }
2819
2820 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2821 // Replace the previous branch op with a new one with block arguments.
2822 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2823 branchOp.getTarget(), blockArgs);
2824 branchOp.erase();
2825 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2826 assert((branchCondOp.getTrueBlock() == target ||
2827 branchCondOp.getFalseBlock() == target) &&
2828 "expected target to be either the true or false target");
2829 if (target == branchCondOp.getTrueTarget())
2830 spirv::BranchConditionalOp::create(
2831 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2832 blockArgs, branchCondOp.getFalseBlockArguments(),
2833 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2834 branchCondOp.getFalseTarget());
2835 else
2836 spirv::BranchConditionalOp::create(
2837 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2838 branchCondOp.getTrueBlockArguments(), blockArgs,
2839 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2840 branchCondOp.getFalseBlock());
2841
2842 branchCondOp.erase();
2843 } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2844 if (target == switchOp.getDefaultTarget()) {
2845 SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
2846 DenseIntElementsAttr literals =
2847 switchOp.getLiterals().value_or(DenseIntElementsAttr());
2848 spirv::SwitchOp::create(
2849 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2850 switchOp.getDefaultTarget(), blockArgs, literals,
2851 switchOp.getTargets(), targetOperands);
2852 switchOp.erase();
2853 } else {
2854 SuccessorRange targets = switchOp.getTargets();
2855 auto it = llvm::find(targets, target);
2856 assert(it != targets.end());
2857 size_t index = std::distance(targets.begin(), it);
2858 switchOp.getTargetOperandsMutable(index).assign(blockArgs);
2859 }
2860 } else {
2861 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2862 }
2863
2864 LLVM_DEBUG({
2865 logger.startLine() << "[phi] after creating block argument:\n";
2866 block->getParentOp()->print(logger.getOStream());
2867 logger.startLine() << "\n";
2868 });
2869 }
2870 blockPhiInfo.clear();
2871
2872 LLVM_DEBUG({
2873 logger.unindent();
2874 logger.startLine()
2875 << "//--- [phi] completed wiring up block arguments ---//\n";
2876 });
2877 return success();
2878}
2879
2881 // Create a copy, so we can modify keys in the original.
2882 BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2883 for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2884 it != e; ++it) {
2885 auto &[block, mergeInfo] = *it;
2886
2887 // Skip processing loop regions. For loop regions continueBlock is non-null.
2888 if (mergeInfo.continueBlock)
2889 continue;
2890
2891 if (!block->mightHaveTerminator())
2892 continue;
2893
2894 Operation *terminator = block->getTerminator();
2895 assert(terminator);
2896
2897 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2898 continue;
2899
2900 // Check if the current header block is a merge block of another construct.
2901 bool splitHeaderMergeBlock = false;
2902 for (const auto &[_, mergeInfo] : blockMergeInfo) {
2903 if (mergeInfo.mergeBlock == block)
2904 splitHeaderMergeBlock = true;
2905 }
2906
2907 // Do not split a block that only contains a conditional branch / switch,
2908 // unless it is also a merge block of another construct - in that case we
2909 // want to split the block. We do not want two constructs to share header /
2910 // merge block.
2911 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2912 Block *newBlock = block->splitBlock(terminator);
2913 OpBuilder builder(block, block->end());
2914 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2915
2916 // After splitting we need to update the map to use the new block as a
2917 // header.
2918 blockMergeInfo.erase(block);
2919 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2920 }
2921 }
2922
2923 return success();
2924}
2925
2927 if (!options.enableControlFlowStructurization) {
2928 LLVM_DEBUG(
2929 {
2930 logger.startLine()
2931 << "//----- [cf] skip structurizing control flow -----//\n";
2932 logger.indent();
2933 });
2934 return success();
2935 }
2936
2937 LLVM_DEBUG({
2938 logger.startLine()
2939 << "//----- [cf] start structurizing control flow -----//\n";
2940 logger.indent();
2941 });
2942
2943 LLVM_DEBUG({
2944 logger.startLine() << "[cf] split conditional blocks\n";
2945 logger.startLine() << "\n";
2946 });
2947
2948 if (failed(splitSelectionHeader())) {
2949 return failure();
2950 }
2951
2952 while (!blockMergeInfo.empty()) {
2953 Block *headerBlock = blockMergeInfo.begin()->first;
2954 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2955
2956 LLVM_DEBUG({
2957 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2958 headerBlock->print(logger.getOStream());
2959 logger.startLine() << "\n";
2960 });
2961
2962 auto *mergeBlock = mergeInfo.mergeBlock;
2963 assert(mergeBlock && "merge block cannot be nullptr");
2964 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2965 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2966 LLVM_DEBUG({
2967 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2968 mergeBlock->print(logger.getOStream());
2969 logger.startLine() << "\n";
2970 });
2971
2972 auto *continueBlock = mergeInfo.continueBlock;
2973 LLVM_DEBUG(if (continueBlock) {
2974 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2975 continueBlock->print(logger.getOStream());
2976 logger.startLine() << "\n";
2977 });
2978 // Erase this case before calling into structurizer, who will update
2979 // blockMergeInfo.
2980 blockMergeInfo.erase(blockMergeInfo.begin());
2981 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2982 blockMergeInfo, headerBlock,
2983 mergeBlock, continueBlock
2984#ifndef NDEBUG
2985 ,
2986 logger
2987#endif
2988 );
2989 if (failed(structurizer.structurize()))
2990 return failure();
2991 }
2992
2993 LLVM_DEBUG({
2994 logger.unindent();
2995 logger.startLine()
2996 << "//--- [cf] completed structurizing control flow ---//\n";
2997 });
2998 return success();
2999}
3000
3001//===----------------------------------------------------------------------===//
3002// Debug
3003//===----------------------------------------------------------------------===//
3004
3006 if (!debugLine)
3007 return unknownLoc;
3008
3009 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3010 if (fileName.empty())
3011 fileName = "<unknown>";
3012 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
3013 debugLine->column);
3014}
3015
3016LogicalResult
3018 // According to SPIR-V spec:
3019 // "This location information applies to the instructions physically
3020 // following this instruction, up to the first occurrence of any of the
3021 // following: the next end of block, the next OpLine instruction, or the next
3022 // OpNoLine instruction."
3023 if (operands.size() != 3)
3024 return emitError(unknownLoc, "OpLine must have 3 operands");
3025 debugLine = DebugLine{operands[0], operands[1], operands[2]};
3026 return success();
3027}
3028
3029void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
3030
3031LogicalResult
3033 if (operands.size() < 2)
3034 return emitError(unknownLoc, "OpString needs at least 2 operands");
3035
3036 if (!debugInfoMap.lookup(operands[0]).empty())
3037 return emitError(unknownLoc,
3038 "duplicate debug string found for result <id> ")
3039 << operands[0];
3040
3041 unsigned wordIndex = 1;
3042 StringRef debugString = decodeStringLiteral(operands, wordIndex);
3043 if (wordIndex != operands.size())
3044 return emitError(unknownLoc,
3045 "unexpected trailing words in OpString instruction");
3046
3047 debugInfoMap[operands[0]] = debugString;
3048 return success();
3049}
return success()
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
void erase()
Unlink this Block from its parent region and delete it.
Definition Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:323
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:109
iterator begin()
Definition Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< BlockOperand > getBlockOperands()
Definition Operation.h:695
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
iterator end()
Definition Region.h:56
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition Region.h:241
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:118
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.