MLIR 22.0.0git
Deserializer.cpp
Go to the documentation of this file.
1//===- Deserializer.cpp - MLIR SPIR-V Deserializer ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the SPIR-V binary to MLIR SPIR-V module deserializer.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Deserializer.h"
14
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Location.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/Sequence.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/StringExtras.h"
27#include "llvm/ADT/bit.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/SaveAndRestore.h"
30#include "llvm/Support/raw_ostream.h"
31#include <optional>
32
33using namespace mlir;
34
35#define DEBUG_TYPE "spirv-deserialization"
36
37//===----------------------------------------------------------------------===//
38// Utility Functions
39//===----------------------------------------------------------------------===//
40
41/// Returns true if the given `block` is a function entry block.
42static inline bool isFnEntryBlock(Block *block) {
43 return block->isEntryBlock() &&
44 isa_and_nonnull<spirv::FuncOp>(block->getParentOp());
45}
46
47//===----------------------------------------------------------------------===//
48// Deserializer Method Definitions
49//===----------------------------------------------------------------------===//
50
51spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
52 MLIRContext *context,
54 : binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
55 module(createModuleOp()), opBuilder(module->getRegion()), options(options)
56#ifndef NDEBUG
57 ,
58 logger(llvm::dbgs())
59#endif
60{
61}
62
63LogicalResult spirv::Deserializer::deserialize() {
64 LLVM_DEBUG({
65 logger.resetIndent();
66 logger.startLine()
67 << "//+++---------- start deserialization ----------+++//\n";
68 });
69
70 if (failed(processHeader()))
71 return failure();
72
73 spirv::Opcode opcode = spirv::Opcode::OpNop;
74 ArrayRef<uint32_t> operands;
75 auto binarySize = binary.size();
76 while (curOffset < binarySize) {
77 // Slice the next instruction out and populate `opcode` and `operands`.
78 // Internally this also updates `curOffset`.
79 if (failed(sliceInstruction(opcode, operands)))
80 return failure();
81
82 if (failed(processInstruction(opcode, operands)))
83 return failure();
84 }
85
86 assert(curOffset == binarySize &&
87 "deserializer should never index beyond the binary end");
88
89 for (auto &deferred : deferredInstructions) {
90 if (failed(processInstruction(deferred.first, deferred.second, false))) {
91 return failure();
92 }
93 }
94
95 attachVCETriple();
96
97 LLVM_DEBUG(logger.startLine()
98 << "//+++-------- completed deserialization --------+++//\n");
99 return success();
100}
101
102OwningOpRef<spirv::ModuleOp> spirv::Deserializer::collect() {
103 return std::move(module);
104}
105
106//===----------------------------------------------------------------------===//
107// Module structure
108//===----------------------------------------------------------------------===//
109
110OwningOpRef<spirv::ModuleOp> spirv::Deserializer::createModuleOp() {
111 OpBuilder builder(context);
112 OperationState state(unknownLoc, spirv::ModuleOp::getOperationName());
113 spirv::ModuleOp::build(builder, state);
114 return cast<spirv::ModuleOp>(Operation::create(state));
115}
116
117LogicalResult spirv::Deserializer::processHeader() {
118 if (binary.size() < spirv::kHeaderWordCount)
119 return emitError(unknownLoc,
120 "SPIR-V binary module must have a 5-word header");
121
122 if (binary[0] != spirv::kMagicNumber)
123 return emitError(unknownLoc, "incorrect magic number");
124
125 // Version number bytes: 0 | major number | minor number | 0
126 uint32_t majorVersion = (binary[1] << 8) >> 24;
127 uint32_t minorVersion = (binary[1] << 16) >> 24;
128 if (majorVersion == 1) {
129 switch (minorVersion) {
130#define MIN_VERSION_CASE(v) \
131 case v: \
132 version = spirv::Version::V_1_##v; \
133 break
134
142#undef MIN_VERSION_CASE
143 default:
144 return emitError(unknownLoc, "unsupported SPIR-V minor version: ")
145 << minorVersion;
146 }
147 } else {
148 return emitError(unknownLoc, "unsupported SPIR-V major version: ")
149 << majorVersion;
150 }
151
152 // TODO: generator number, bound, schema
153 curOffset = spirv::kHeaderWordCount;
154 return success();
155}
156
157LogicalResult
158spirv::Deserializer::processCapability(ArrayRef<uint32_t> operands) {
159 if (operands.size() != 1)
160 return emitError(unknownLoc, "OpCapability must have one parameter");
161
162 auto cap = spirv::symbolizeCapability(operands[0]);
163 if (!cap)
164 return emitError(unknownLoc, "unknown capability: ") << operands[0];
165
166 capabilities.insert(*cap);
167 return success();
168}
169
170LogicalResult spirv::Deserializer::processExtension(ArrayRef<uint32_t> words) {
171 if (words.empty()) {
172 return emitError(
173 unknownLoc,
174 "OpExtension must have a literal string for the extension name");
175 }
176
177 unsigned wordIndex = 0;
178 StringRef extName = decodeStringLiteral(words, wordIndex);
179 if (wordIndex != words.size())
180 return emitError(unknownLoc,
181 "unexpected trailing words in OpExtension instruction");
182 auto ext = spirv::symbolizeExtension(extName);
183 if (!ext)
184 return emitError(unknownLoc, "unknown extension: ") << extName;
185
186 extensions.insert(*ext);
187 return success();
188}
189
190LogicalResult
191spirv::Deserializer::processExtInstImport(ArrayRef<uint32_t> words) {
192 if (words.size() < 2) {
193 return emitError(unknownLoc,
194 "OpExtInstImport must have a result <id> and a literal "
195 "string for the extended instruction set name");
196 }
197
198 unsigned wordIndex = 1;
199 extendedInstSets[words[0]] = decodeStringLiteral(words, wordIndex);
200 if (wordIndex != words.size()) {
201 return emitError(unknownLoc,
202 "unexpected trailing words in OpExtInstImport");
203 }
204 return success();
205}
206
207void spirv::Deserializer::attachVCETriple() {
208 (*module)->setAttr(
209 spirv::ModuleOp::getVCETripleAttrName(),
210 spirv::VerCapExtAttr::get(version, capabilities.getArrayRef(),
211 extensions.getArrayRef(), context));
212}
213
214LogicalResult
215spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
216 if (operands.size() != 2)
217 return emitError(unknownLoc, "OpMemoryModel must have two operands");
218
219 (*module)->setAttr(
220 module->getAddressingModelAttrName(),
221 opBuilder.getAttr<spirv::AddressingModelAttr>(
222 static_cast<spirv::AddressingModel>(operands.front())));
223
224 (*module)->setAttr(module->getMemoryModelAttrName(),
225 opBuilder.getAttr<spirv::MemoryModelAttr>(
226 static_cast<spirv::MemoryModel>(operands.back())));
227
228 return success();
229}
230
231template <typename AttrTy, typename EnumAttrTy, typename EnumTy>
233 Location loc, OpBuilder &opBuilder,
235 StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) {
236 if (words.size() != 4) {
237 return emitError(loc, "OpDecoration with ")
238 << decorationName << "needs a cache control integer literal and a "
239 << cacheControlKind << " cache control literal";
240 }
241 unsigned cacheLevel = words[2];
242 auto cacheControlAttr = static_cast<EnumTy>(words[3]);
243 auto value = opBuilder.getAttr<AttrTy>(cacheLevel, cacheControlAttr);
245 if (auto attrList =
246 dyn_cast_or_null<ArrayAttr>(decorations[words[0]].get(symbol)))
247 llvm::append_range(attrs, attrList);
248 attrs.push_back(value);
249 decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs));
250 return success();
251}
252
253LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
254 // TODO: This function should also be auto-generated. For now, since only a
255 // few decorations are processed/handled in a meaningful manner, going with a
256 // manual implementation.
257 if (words.size() < 2) {
258 return emitError(
259 unknownLoc, "OpDecorate must have at least result <id> and Decoration");
260 }
261 auto decorationName =
262 stringifyDecoration(static_cast<spirv::Decoration>(words[1]));
263 if (decorationName.empty()) {
264 return emitError(unknownLoc, "invalid Decoration code : ") << words[1];
265 }
266 auto symbol = getSymbolDecoration(decorationName);
267 switch (static_cast<spirv::Decoration>(words[1])) {
268 case spirv::Decoration::FPFastMathMode:
269 if (words.size() != 3) {
270 return emitError(unknownLoc, "OpDecorate with ")
271 << decorationName << " needs a single integer literal";
272 }
273 decorations[words[0]].set(
274 symbol, FPFastMathModeAttr::get(opBuilder.getContext(),
275 static_cast<FPFastMathMode>(words[2])));
276 break;
277 case spirv::Decoration::FPRoundingMode:
278 if (words.size() != 3) {
279 return emitError(unknownLoc, "OpDecorate with ")
280 << decorationName << " needs a single integer literal";
281 }
282 decorations[words[0]].set(
283 symbol, FPRoundingModeAttr::get(opBuilder.getContext(),
284 static_cast<FPRoundingMode>(words[2])));
285 break;
286 case spirv::Decoration::DescriptorSet:
287 case spirv::Decoration::Binding:
288 if (words.size() != 3) {
289 return emitError(unknownLoc, "OpDecorate with ")
290 << decorationName << " needs a single integer literal";
291 }
292 decorations[words[0]].set(
293 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
294 break;
295 case spirv::Decoration::BuiltIn:
296 if (words.size() != 3) {
297 return emitError(unknownLoc, "OpDecorate with ")
298 << decorationName << " needs a single integer literal";
299 }
300 decorations[words[0]].set(
301 symbol, opBuilder.getStringAttr(
302 stringifyBuiltIn(static_cast<spirv::BuiltIn>(words[2]))));
303 break;
304 case spirv::Decoration::ArrayStride:
305 if (words.size() != 3) {
306 return emitError(unknownLoc, "OpDecorate with ")
307 << decorationName << " needs a single integer literal";
308 }
309 typeDecorations[words[0]] = words[2];
310 break;
311 case spirv::Decoration::LinkageAttributes: {
312 if (words.size() < 4) {
313 return emitError(unknownLoc, "OpDecorate with ")
314 << decorationName
315 << " needs at least 1 string and 1 integer literal";
316 }
317 // LinkageAttributes has two parameters ["linkageName", linkageType]
318 // e.g., OpDecorate %imported_func LinkageAttributes "outside.func" Import
319 // "linkageName" is a stringliteral encoded as uint32_t,
320 // hence the size of name is variable length which results in words.size()
321 // being variable length, words.size() = 3 + strlen(name)/4 + 1 or
322 // 3 + ceildiv(strlen(name), 4).
323 unsigned wordIndex = 2;
324 auto linkageName = spirv::decodeStringLiteral(words, wordIndex).str();
325 auto linkageTypeAttr = opBuilder.getAttr<::mlir::spirv::LinkageTypeAttr>(
326 static_cast<::mlir::spirv::LinkageType>(words[wordIndex++]));
327 auto linkageAttr = opBuilder.getAttr<::mlir::spirv::LinkageAttributesAttr>(
328 StringAttr::get(context, linkageName), linkageTypeAttr);
329 decorations[words[0]].set(symbol, dyn_cast<Attribute>(linkageAttr));
330 break;
331 }
332 case spirv::Decoration::Aliased:
333 case spirv::Decoration::AliasedPointer:
334 case spirv::Decoration::Block:
335 case spirv::Decoration::BufferBlock:
336 case spirv::Decoration::Flat:
337 case spirv::Decoration::NonReadable:
338 case spirv::Decoration::NonWritable:
339 case spirv::Decoration::NoPerspective:
340 case spirv::Decoration::NoSignedWrap:
341 case spirv::Decoration::NoUnsignedWrap:
342 case spirv::Decoration::RelaxedPrecision:
343 case spirv::Decoration::Restrict:
344 case spirv::Decoration::RestrictPointer:
345 case spirv::Decoration::NoContraction:
346 case spirv::Decoration::Constant:
347 case spirv::Decoration::Invariant:
348 case spirv::Decoration::Patch:
349 case spirv::Decoration::Coherent:
350 if (words.size() != 2) {
351 return emitError(unknownLoc, "OpDecoration with ")
352 << decorationName << "needs a single target <id>";
353 }
354 decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
355 break;
356 case spirv::Decoration::Location:
357 case spirv::Decoration::SpecId:
358 case spirv::Decoration::Index:
359 if (words.size() != 3) {
360 return emitError(unknownLoc, "OpDecoration with ")
361 << decorationName << "needs a single integer literal";
362 }
363 decorations[words[0]].set(
364 symbol, opBuilder.getI32IntegerAttr(static_cast<int32_t>(words[2])));
365 break;
366 case spirv::Decoration::CacheControlLoadINTEL: {
367 LogicalResult res = deserializeCacheControlDecoration<
368 CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>(
369 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
370 "load");
371 if (failed(res))
372 return res;
373 break;
374 }
375 case spirv::Decoration::CacheControlStoreINTEL: {
376 LogicalResult res = deserializeCacheControlDecoration<
377 CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>(
378 unknownLoc, opBuilder, decorations, words, symbol, decorationName,
379 "store");
380 if (failed(res))
381 return res;
382 break;
383 }
384 default:
385 return emitError(unknownLoc, "unhandled Decoration : '") << decorationName;
386 }
387 return success();
388}
389
390LogicalResult
391spirv::Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
392 // The binary layout of OpMemberDecorate is different comparing to OpDecorate
393 if (words.size() < 3) {
394 return emitError(unknownLoc,
395 "OpMemberDecorate must have at least 3 operands");
396 }
397
398 auto decoration = static_cast<spirv::Decoration>(words[2]);
399 if (decoration == spirv::Decoration::Offset && words.size() != 4) {
400 return emitError(unknownLoc,
401 " missing offset specification in OpMemberDecorate with "
402 "Offset decoration");
403 }
404 ArrayRef<uint32_t> decorationOperands;
405 if (words.size() > 3) {
406 decorationOperands = words.slice(3);
407 }
408 memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands;
409 return success();
410}
411
412LogicalResult spirv::Deserializer::processMemberName(ArrayRef<uint32_t> words) {
413 if (words.size() < 3) {
414 return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
415 }
416 unsigned wordIndex = 2;
417 auto name = decodeStringLiteral(words, wordIndex);
418 if (wordIndex != words.size()) {
419 return emitError(unknownLoc,
420 "unexpected trailing words in OpMemberName instruction");
421 }
422 memberNameMap[words[0]][words[1]] = name;
423 return success();
424}
425
427 uint32_t argID, SmallVectorImpl<Attribute> &argAttrs, size_t argIndex) {
428 if (!decorations.contains(argID)) {
429 argAttrs[argIndex] = DictionaryAttr::get(context, {});
430 return success();
431 }
432
433 spirv::DecorationAttr foundDecorationAttr;
434 for (NamedAttribute decAttr : decorations[argID]) {
435 for (auto decoration :
436 {spirv::Decoration::Aliased, spirv::Decoration::Restrict,
437 spirv::Decoration::AliasedPointer,
438 spirv::Decoration::RestrictPointer}) {
439
440 if (decAttr.getName() !=
441 getSymbolDecoration(stringifyDecoration(decoration)))
442 continue;
443
444 if (foundDecorationAttr)
445 return emitError(unknownLoc,
446 "more than one Aliased/Restrict decorations for "
447 "function argument with result <id> ")
448 << argID;
449
450 foundDecorationAttr = spirv::DecorationAttr::get(context, decoration);
451 break;
452 }
453
454 if (decAttr.getName() == getSymbolDecoration(stringifyDecoration(
455 spirv::Decoration::RelaxedPrecision))) {
456 // TODO: Current implementation supports only one decoration per function
457 // parameter so RelaxedPrecision cannot be applied at the same time as,
458 // for example, Aliased/Restrict/etc. This should be relaxed to allow any
459 // combination of decoration allowed by the spec to be supported.
460 if (foundDecorationAttr)
461 return emitError(unknownLoc, "already found a decoration for function "
462 "argument with result <id> ")
463 << argID;
464
465 foundDecorationAttr = spirv::DecorationAttr::get(
466 context, spirv::Decoration::RelaxedPrecision);
467 }
468 }
469
470 if (!foundDecorationAttr)
471 return emitError(unknownLoc, "unimplemented decoration support for "
472 "function argument with result <id> ")
473 << argID;
474
475 NamedAttribute attr(StringAttr::get(context, spirv::DecorationAttr::name),
476 foundDecorationAttr);
477 argAttrs[argIndex] = DictionaryAttr::get(context, attr);
478 return success();
479}
480
481LogicalResult
483 if (curFunction) {
484 return emitError(unknownLoc, "found function inside function");
485 }
486
487 // Get the result type
488 if (operands.size() != 4) {
489 return emitError(unknownLoc, "OpFunction must have 4 parameters");
490 }
491 Type resultType = getType(operands[0]);
492 if (!resultType) {
493 return emitError(unknownLoc, "undefined result type from <id> ")
494 << operands[0];
495 }
496
497 uint32_t fnID = operands[1];
498 if (funcMap.count(fnID)) {
499 return emitError(unknownLoc, "duplicate function definition/declaration");
500 }
501
502 auto fnControl = spirv::symbolizeFunctionControl(operands[2]);
503 if (!fnControl) {
504 return emitError(unknownLoc, "unknown Function Control: ") << operands[2];
505 }
506
507 Type fnType = getType(operands[3]);
508 if (!fnType || !isa<FunctionType>(fnType)) {
509 return emitError(unknownLoc, "unknown function type from <id> ")
510 << operands[3];
511 }
512 auto functionType = cast<FunctionType>(fnType);
513
514 if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
515 (functionType.getNumResults() == 1 &&
516 functionType.getResult(0) != resultType)) {
517 return emitError(unknownLoc, "mismatch in function type ")
518 << functionType << " and return type " << resultType << " specified";
519 }
520
521 std::string fnName = getFunctionSymbol(fnID);
522 auto funcOp = spirv::FuncOp::create(opBuilder, unknownLoc, fnName,
523 functionType, fnControl.value());
524 // Processing other function attributes.
525 if (decorations.count(fnID)) {
526 for (auto attr : decorations[fnID].getAttrs()) {
527 funcOp->setAttr(attr.getName(), attr.getValue());
528 }
529 }
530 curFunction = funcMap[fnID] = funcOp;
531 auto *entryBlock = funcOp.addEntryBlock();
532 LLVM_DEBUG({
533 logger.startLine()
534 << "//===-------------------------------------------===//\n";
535 logger.startLine() << "[fn] name: " << fnName << "\n";
536 logger.startLine() << "[fn] type: " << fnType << "\n";
537 logger.startLine() << "[fn] ID: " << fnID << "\n";
538 logger.startLine() << "[fn] entry block: " << entryBlock << "\n";
539 logger.indent();
540 });
541
542 SmallVector<Attribute> argAttrs;
543 argAttrs.resize(functionType.getNumInputs());
544
545 // Parse the op argument instructions
546 if (functionType.getNumInputs()) {
547 for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
548 auto argType = functionType.getInput(i);
549 spirv::Opcode opcode = spirv::Opcode::OpNop;
550 ArrayRef<uint32_t> operands;
551 if (failed(sliceInstruction(opcode, operands,
552 spirv::Opcode::OpFunctionParameter))) {
553 return failure();
554 }
555 if (opcode != spirv::Opcode::OpFunctionParameter) {
556 return emitError(
557 unknownLoc,
558 "missing OpFunctionParameter instruction for argument ")
559 << i;
560 }
561 if (operands.size() != 2) {
562 return emitError(
563 unknownLoc,
564 "expected result type and result <id> for OpFunctionParameter");
565 }
566 auto argDefinedType = getType(operands[0]);
567 if (!argDefinedType || argDefinedType != argType) {
568 return emitError(unknownLoc,
569 "mismatch in argument type between function type "
570 "definition ")
571 << functionType << " and argument type definition "
572 << argDefinedType << " at argument " << i;
573 }
574 if (getValue(operands[1])) {
575 return emitError(unknownLoc, "duplicate definition of result <id> ")
576 << operands[1];
577 }
578 if (failed(setFunctionArgAttrs(operands[1], argAttrs, i))) {
579 return failure();
580 }
581
582 auto argValue = funcOp.getArgument(i);
583 valueMap[operands[1]] = argValue;
584 }
585 }
586
587 if (llvm::any_of(argAttrs, [](Attribute attr) {
588 auto argAttr = cast<DictionaryAttr>(attr);
589 return !argAttr.empty();
590 }))
591 funcOp.setArgAttrsAttr(ArrayAttr::get(context, argAttrs));
592
593 // entryBlock is needed to access the arguments, Once that is done, we can
594 // erase the block for functions with 'Import' LinkageAttributes, since these
595 // are essentially function declarations, so they have no body.
596 auto linkageAttr = funcOp.getLinkageAttributes();
597 auto hasImportLinkage =
598 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
599 spirv::LinkageType::Import);
600 if (hasImportLinkage)
601 funcOp.eraseBody();
602
603 // RAII guard to reset the insertion point to the module's region after
604 // deserializing the body of this function.
605 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
606
607 spirv::Opcode opcode = spirv::Opcode::OpNop;
608 ArrayRef<uint32_t> instOperands;
609
610 // Special handling for the entry block. We need to make sure it starts with
611 // an OpLabel instruction. The entry block takes the same parameters as the
612 // function. All other blocks do not take any parameter. We have already
613 // created the entry block, here we need to register it to the correct label
614 // <id>.
615 if (failed(sliceInstruction(opcode, instOperands,
616 spirv::Opcode::OpFunctionEnd))) {
617 return failure();
618 }
619 if (opcode == spirv::Opcode::OpFunctionEnd) {
620 return processFunctionEnd(instOperands);
621 }
622 if (opcode != spirv::Opcode::OpLabel) {
623 return emitError(unknownLoc, "a basic block must start with OpLabel");
624 }
625 if (instOperands.size() != 1) {
626 return emitError(unknownLoc, "OpLabel should only have result <id>");
627 }
628 blockMap[instOperands[0]] = entryBlock;
629 if (failed(processLabel(instOperands))) {
630 return failure();
631 }
632
633 // Then process all the other instructions in the function until we hit
634 // OpFunctionEnd.
635 while (succeeded(sliceInstruction(opcode, instOperands,
636 spirv::Opcode::OpFunctionEnd)) &&
637 opcode != spirv::Opcode::OpFunctionEnd) {
638 if (failed(processInstruction(opcode, instOperands))) {
639 return failure();
640 }
641 }
642 if (opcode != spirv::Opcode::OpFunctionEnd) {
643 return failure();
644 }
645
646 return processFunctionEnd(instOperands);
647}
648
649LogicalResult
651 // Process OpFunctionEnd.
652 if (!operands.empty()) {
653 return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
654 }
655
656 // Wire up block arguments from OpPhi instructions.
657 // Put all structured control flow in spirv.mlir.selection/spirv.mlir.loop
658 // ops.
659 if (failed(wireUpBlockArgument()) || failed(structurizeControlFlow())) {
660 return failure();
661 }
662
663 curBlock = nullptr;
664 curFunction = std::nullopt;
665
666 LLVM_DEBUG({
667 logger.unindent();
668 logger.startLine()
669 << "//===-------------------------------------------===//\n";
670 });
671 return success();
672}
673
674LogicalResult
676 if (operands.size() < 2) {
677 return emitError(unknownLoc,
678 "missing graph defintion in OpGraphEntryPointARM");
679 }
680
681 unsigned wordIndex = 0;
682 uint32_t graphID = operands[wordIndex++];
683 if (!graphMap.contains(graphID)) {
684 return emitError(unknownLoc,
685 "missing graph definition/declaration with id ")
686 << graphID;
687 }
688
689 spirv::GraphARMOp graphARM = graphMap[graphID];
690 StringRef name = decodeStringLiteral(operands, wordIndex);
691 graphARM.setSymName(name);
692 graphARM.setEntryPoint(true);
693
695 for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
696 if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
697 interface.push_back(SymbolRefAttr::get(arg.getOperation()));
698 } else {
699 return emitError(unknownLoc, "undefined result <id> ")
700 << operands[wordIndex] << " while decoding OpGraphEntryPoint";
701 }
702 }
703
704 // RAII guard to reset the insertion point to previous value when done.
705 OpBuilder::InsertionGuard insertionGuard(opBuilder);
706 opBuilder.setInsertionPoint(graphARM);
707 spirv::GraphEntryPointARMOp::create(
708 opBuilder, unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
709 opBuilder.getArrayAttr(interface));
710
711 return success();
712}
713
714LogicalResult
716 if (curGraph) {
717 return emitError(unknownLoc, "found graph inside graph");
718 }
719 // Get the result type.
720 if (operands.size() < 2) {
721 return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
722 }
723
724 Type type = getType(operands[0]);
725 if (!type || !isa<GraphType>(type)) {
726 return emitError(unknownLoc, "unknown graph type from <id> ")
727 << operands[0];
728 }
729 auto graphType = cast<GraphType>(type);
730 if (graphType.getNumResults() <= 0) {
731 return emitError(unknownLoc, "expected at least one result");
732 }
733
734 uint32_t graphID = operands[1];
735 if (graphMap.count(graphID)) {
736 return emitError(unknownLoc, "duplicate graph definition/declaration");
737 }
738
739 std::string graphName = getGraphSymbol(graphID);
740 auto graphOp =
741 spirv::GraphARMOp::create(opBuilder, unknownLoc, graphName, graphType);
742 curGraph = graphMap[graphID] = graphOp;
743 Block *entryBlock = graphOp.addEntryBlock();
744 LLVM_DEBUG({
745 logger.startLine()
746 << "//===-------------------------------------------===//\n";
747 logger.startLine() << "[graph] name: " << graphName << "\n";
748 logger.startLine() << "[graph] type: " << graphType << "\n";
749 logger.startLine() << "[graph] ID: " << graphID << "\n";
750 logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
751 logger.indent();
752 });
753
754 // Parse the op argument instructions.
755 for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
756 spirv::Opcode opcode;
757 ArrayRef<uint32_t> operands;
758 if (failed(sliceInstruction(opcode, operands,
759 spirv::Opcode::OpGraphInputARM))) {
760 return failure();
761 }
762 if (operands.size() != 3) {
763 return emitError(unknownLoc, "expected result type, result <id> and "
764 "input index for OpGraphInputARM");
765 }
766
767 Type argDefinedType = getType(operands[0]);
768 if (!argDefinedType) {
769 return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
770 }
771
772 if (argDefinedType != argType) {
773 return emitError(unknownLoc,
774 "mismatch in argument type between graph type "
775 "definition ")
776 << graphType << " and argument type definition " << argDefinedType
777 << " at argument " << index;
778 }
779 if (getValue(operands[1])) {
780 return emitError(unknownLoc, "duplicate definition of result <id> ")
781 << operands[1];
782 }
783
784 IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
785 if (!inputIndexAttr) {
786 return emitError(unknownLoc,
787 "unable to read inputIndex value from constant op ")
788 << operands[2];
789 }
790 BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
791 valueMap[operands[1]] = argValue;
792 }
793
794 graphOutputs.resize(graphType.getNumResults());
795
796 // RAII guard to reset the insertion point to the module's region after
797 // deserializing the body of this function.
798 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
799
800 blockMap[graphID] = entryBlock;
801 if (failed(createGraphBlock(graphID))) {
802 return failure();
803 }
804
805 // Process all the instructions in the graph until and including
806 // OpGraphEndARM.
807 spirv::Opcode opcode;
808 ArrayRef<uint32_t> instOperands;
809 do {
810 if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
811 return failure();
812 }
813
814 if (failed(processInstruction(opcode, instOperands))) {
815 return failure();
816 }
817 } while (opcode != spirv::Opcode::OpGraphEndARM);
818
819 return success();
820}
821
822LogicalResult
824 if (operands.size() != 2) {
825 return emitError(
826 unknownLoc,
827 "expected value id and output index for OpGraphSetOutputARM");
828 }
829
830 uint32_t id = operands[0];
831 Value value = getValue(id);
832 if (!value) {
833 return emitError(unknownLoc, "could not find result <id> ") << id;
834 }
835
836 IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
837 if (!outputIndexAttr) {
838 return emitError(unknownLoc,
839 "unable to read outputIndex value from constant op ")
840 << operands[1];
841 }
842 graphOutputs[outputIndexAttr.getInt()] = value;
843 return success();
844}
845
846LogicalResult
848 // Create GraphOutputsARM instruction.
849 spirv::GraphOutputsARMOp::create(opBuilder, unknownLoc, graphOutputs);
850
851 // Process OpGraphEndARM.
852 if (!operands.empty()) {
853 return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
854 }
855
856 curBlock = nullptr;
857 curGraph = std::nullopt;
858 graphOutputs.clear();
859
860 LLVM_DEBUG({
861 logger.unindent();
862 logger.startLine()
863 << "//===-------------------------------------------===//\n";
864 });
865 return success();
866}
867
868std::optional<std::pair<Attribute, Type>>
870 auto constIt = constantMap.find(id);
871 if (constIt == constantMap.end())
872 return std::nullopt;
873 return constIt->getSecond();
874}
875
876std::optional<std::pair<Attribute, Type>>
878 if (auto it = constantCompositeReplicateMap.find(id);
879 it != constantCompositeReplicateMap.end())
880 return it->second;
881 return std::nullopt;
882}
883
884std::optional<spirv::SpecConstOperationMaterializationInfo>
886 auto constIt = specConstOperationMap.find(id);
887 if (constIt == specConstOperationMap.end())
888 return std::nullopt;
889 return constIt->getSecond();
890}
891
893 auto funcName = nameMap.lookup(id).str();
894 if (funcName.empty()) {
895 funcName = "spirv_fn_" + std::to_string(id);
896 }
897 return funcName;
898}
899
900std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
901 std::string graphName = nameMap.lookup(id).str();
902 if (graphName.empty()) {
903 graphName = "spirv_graph_" + std::to_string(id);
904 }
905 return graphName;
906}
907
909 auto constName = nameMap.lookup(id).str();
910 if (constName.empty()) {
911 constName = "spirv_spec_const_" + std::to_string(id);
912 }
913 return constName;
914}
915
916spirv::SpecConstantOp
918 TypedAttr defaultValue) {
919 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
920 auto op = spirv::SpecConstantOp::create(opBuilder, unknownLoc, symName,
921 defaultValue);
922 if (decorations.count(resultID)) {
923 for (auto attr : decorations[resultID].getAttrs())
924 op->setAttr(attr.getName(), attr.getValue());
925 }
926 specConstMap[resultID] = op;
927 return op;
928}
929
930std::optional<spirv::GraphConstantARMOpMaterializationInfo>
932 auto graphConstIt = graphConstantMap.find(id);
933 if (graphConstIt == graphConstantMap.end())
934 return std::nullopt;
935 return graphConstIt->getSecond();
936}
937
938LogicalResult
940 unsigned wordIndex = 0;
941 if (operands.size() < 3) {
942 return emitError(
943 unknownLoc,
944 "OpVariable needs at least 3 operands, type, <id> and storage class");
945 }
946
947 // Result Type.
948 auto type = getType(operands[wordIndex]);
949 if (!type) {
950 return emitError(unknownLoc, "unknown result type <id> : ")
951 << operands[wordIndex];
952 }
953 auto ptrType = dyn_cast<spirv::PointerType>(type);
954 if (!ptrType) {
955 return emitError(unknownLoc,
956 "expected a result type <id> to be a spirv.ptr, found : ")
957 << type;
958 }
959 wordIndex++;
960
961 // Result <id>.
962 auto variableID = operands[wordIndex];
963 auto variableName = nameMap.lookup(variableID).str();
964 if (variableName.empty()) {
965 variableName = "spirv_var_" + std::to_string(variableID);
966 }
967 wordIndex++;
968
969 // Storage class.
970 auto storageClass = static_cast<spirv::StorageClass>(operands[wordIndex]);
971 if (ptrType.getStorageClass() != storageClass) {
972 return emitError(unknownLoc, "mismatch in storage class of pointer type ")
973 << type << " and that specified in OpVariable instruction : "
974 << stringifyStorageClass(storageClass);
975 }
976 wordIndex++;
977
978 // Initializer.
979 FlatSymbolRefAttr initializer = nullptr;
980
981 if (wordIndex < operands.size()) {
982 Operation *op = nullptr;
983
984 if (auto initOp = getGlobalVariable(operands[wordIndex]))
985 op = initOp;
986 else if (auto initOp = getSpecConstant(operands[wordIndex]))
987 op = initOp;
988 else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
989 op = initOp;
990 else
991 return emitError(unknownLoc, "unknown <id> ")
992 << operands[wordIndex] << "used as initializer";
993
994 initializer = SymbolRefAttr::get(op);
995 wordIndex++;
996 }
997 if (wordIndex != operands.size()) {
998 return emitError(unknownLoc,
999 "found more operands than expected when deserializing "
1000 "OpVariable instruction, only ")
1001 << wordIndex << " of " << operands.size() << " processed";
1002 }
1003 auto loc = createFileLineColLoc(opBuilder);
1004 auto varOp = spirv::GlobalVariableOp::create(
1005 opBuilder, loc, TypeAttr::get(type),
1006 opBuilder.getStringAttr(variableName), initializer);
1007
1008 // Decorations.
1009 if (decorations.count(variableID)) {
1010 for (auto attr : decorations[variableID].getAttrs())
1011 varOp->setAttr(attr.getName(), attr.getValue());
1012 }
1013 globalVariableMap[variableID] = varOp;
1014 return success();
1015}
1016
1017IntegerAttr spirv::Deserializer::getConstantInt(uint32_t id) {
1018 auto constInfo = getConstant(id);
1019 if (!constInfo) {
1020 return nullptr;
1021 }
1022 return dyn_cast<IntegerAttr>(constInfo->first);
1023}
1024
1025LogicalResult spirv::Deserializer::processName(ArrayRef<uint32_t> operands) {
1026 if (operands.size() < 2) {
1027 return emitError(unknownLoc, "OpName needs at least 2 operands");
1028 }
1029 if (!nameMap.lookup(operands[0]).empty()) {
1030 return emitError(unknownLoc, "duplicate name found for result <id> ")
1031 << operands[0];
1032 }
1033 unsigned wordIndex = 1;
1034 StringRef name = decodeStringLiteral(operands, wordIndex);
1035 if (wordIndex != operands.size()) {
1036 return emitError(unknownLoc,
1037 "unexpected trailing words in OpName instruction");
1038 }
1039 nameMap[operands[0]] = name;
1040 return success();
1041}
1042
1043//===----------------------------------------------------------------------===//
1044// Type
1045//===----------------------------------------------------------------------===//
1046
1047LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
1048 ArrayRef<uint32_t> operands) {
1049 if (operands.empty()) {
1050 return emitError(unknownLoc, "type instruction with opcode ")
1051 << spirv::stringifyOpcode(opcode) << " needs at least one <id>";
1052 }
1053
1054 /// TODO: Types might be forward declared in some instructions and need to be
1055 /// handled appropriately.
1056 if (typeMap.count(operands[0])) {
1057 return emitError(unknownLoc, "duplicate definition for result <id> ")
1058 << operands[0];
1059 }
1060
1061 switch (opcode) {
1062 case spirv::Opcode::OpTypeVoid:
1063 if (operands.size() != 1)
1064 return emitError(unknownLoc, "OpTypeVoid must have no parameters");
1065 typeMap[operands[0]] = opBuilder.getNoneType();
1066 break;
1067 case spirv::Opcode::OpTypeBool:
1068 if (operands.size() != 1)
1069 return emitError(unknownLoc, "OpTypeBool must have no parameters");
1070 typeMap[operands[0]] = opBuilder.getI1Type();
1071 break;
1072 case spirv::Opcode::OpTypeInt: {
1073 if (operands.size() != 3)
1074 return emitError(
1075 unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
1076
1077 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
1078 // to preserve or validate.
1079 // 0 indicates unsigned, or no signedness semantics
1080 // 1 indicates signed semantics."
1081 //
1082 // So we cannot differentiate signless and unsigned integers; always use
1083 // signless semantics for such cases.
1084 auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
1085 : IntegerType::SignednessSemantics::Signless;
1086 typeMap[operands[0]] = IntegerType::get(context, operands[1], sign);
1087 } break;
1088 case spirv::Opcode::OpTypeFloat: {
1089 if (operands.size() != 2 && operands.size() != 3)
1090 return emitError(unknownLoc,
1091 "OpTypeFloat expects either 2 operands (type, bitwidth) "
1092 "or 3 operands (type, bitwidth, encoding), but got ")
1093 << operands.size();
1094 uint32_t bitWidth = operands[1];
1095
1096 Type floatTy;
1097 switch (bitWidth) {
1098 case 16:
1099 floatTy = opBuilder.getF16Type();
1100 break;
1101 case 32:
1102 floatTy = opBuilder.getF32Type();
1103 break;
1104 case 64:
1105 floatTy = opBuilder.getF64Type();
1106 break;
1107 default:
1108 return emitError(unknownLoc, "unsupported OpTypeFloat bitwidth: ")
1109 << bitWidth;
1110 }
1111
1112 if (operands.size() == 3) {
1113 if (spirv::FPEncoding(operands[2]) != spirv::FPEncoding::BFloat16KHR)
1114 return emitError(unknownLoc, "unsupported OpTypeFloat FP encoding: ")
1115 << operands[2];
1116 if (bitWidth != 16)
1117 return emitError(unknownLoc,
1118 "invalid OpTypeFloat bitwidth for bfloat16 encoding: ")
1119 << bitWidth << " (expected 16)";
1120 floatTy = opBuilder.getBF16Type();
1121 }
1122
1123 typeMap[operands[0]] = floatTy;
1124 } break;
1125 case spirv::Opcode::OpTypeVector: {
1126 if (operands.size() != 3) {
1127 return emitError(
1128 unknownLoc,
1129 "OpTypeVector must have element type and count parameters");
1130 }
1131 Type elementTy = getType(operands[1]);
1132 if (!elementTy) {
1133 return emitError(unknownLoc, "OpTypeVector references undefined <id> ")
1134 << operands[1];
1135 }
1136 typeMap[operands[0]] = VectorType::get({operands[2]}, elementTy);
1137 } break;
1138 case spirv::Opcode::OpTypePointer: {
1139 return processOpTypePointer(operands);
1140 } break;
1141 case spirv::Opcode::OpTypeArray:
1142 return processArrayType(operands);
1143 case spirv::Opcode::OpTypeCooperativeMatrixKHR:
1144 return processCooperativeMatrixTypeKHR(operands);
1145 case spirv::Opcode::OpTypeFunction:
1146 return processFunctionType(operands);
1147 case spirv::Opcode::OpTypeImage:
1148 return processImageType(operands);
1149 case spirv::Opcode::OpTypeSampledImage:
1150 return processSampledImageType(operands);
1151 case spirv::Opcode::OpTypeRuntimeArray:
1152 return processRuntimeArrayType(operands);
1153 case spirv::Opcode::OpTypeStruct:
1154 return processStructType(operands);
1155 case spirv::Opcode::OpTypeMatrix:
1156 return processMatrixType(operands);
1157 case spirv::Opcode::OpTypeTensorARM:
1158 return processTensorARMType(operands);
1159 case spirv::Opcode::OpTypeGraphARM:
1160 return processGraphTypeARM(operands);
1161 default:
1162 return emitError(unknownLoc, "unhandled type instruction");
1163 }
1164 return success();
1165}
1166
1167LogicalResult
1169 if (operands.size() != 3)
1170 return emitError(unknownLoc, "OpTypePointer must have two parameters");
1171
1172 auto pointeeType = getType(operands[2]);
1173 if (!pointeeType)
1174 return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> ")
1175 << operands[2];
1176
1177 uint32_t typePointerID = operands[0];
1178 auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
1179 typeMap[typePointerID] = spirv::PointerType::get(pointeeType, storageClass);
1180
1181 for (auto *deferredStructIt = std::begin(deferredStructTypesInfos);
1182 deferredStructIt != std::end(deferredStructTypesInfos);) {
1183 for (auto *unresolvedMemberIt =
1184 std::begin(deferredStructIt->unresolvedMemberTypes);
1185 unresolvedMemberIt !=
1186 std::end(deferredStructIt->unresolvedMemberTypes);) {
1187 if (unresolvedMemberIt->first == typePointerID) {
1188 // The newly constructed pointer type can resolve one of the
1189 // deferred struct type members; update the memberTypes list and
1190 // clean the unresolvedMemberTypes list accordingly.
1191 deferredStructIt->memberTypes[unresolvedMemberIt->second] =
1192 typeMap[typePointerID];
1193 unresolvedMemberIt =
1194 deferredStructIt->unresolvedMemberTypes.erase(unresolvedMemberIt);
1195 } else {
1196 ++unresolvedMemberIt;
1197 }
1198 }
1199
1200 if (deferredStructIt->unresolvedMemberTypes.empty()) {
1201 // All deferred struct type members are now resolved, set the struct body.
1202 auto structType = deferredStructIt->deferredStructType;
1203
1204 assert(structType && "expected a spirv::StructType");
1205 assert(structType.isIdentified() && "expected an indentified struct");
1206
1207 if (failed(structType.trySetBody(
1208 deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
1209 deferredStructIt->memberDecorationsInfo,
1210 deferredStructIt->structDecorationsInfo)))
1211 return failure();
1212
1213 deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
1214 } else {
1215 ++deferredStructIt;
1216 }
1217 }
1218
1219 return success();
1220}
1221
1222LogicalResult
1224 if (operands.size() != 3) {
1225 return emitError(unknownLoc,
1226 "OpTypeArray must have element type and count parameters");
1227 }
1228
1229 Type elementTy = getType(operands[1]);
1230 if (!elementTy) {
1231 return emitError(unknownLoc, "OpTypeArray references undefined <id> ")
1232 << operands[1];
1233 }
1234
1235 unsigned count = 0;
1236 // TODO: The count can also come frome a specialization constant.
1237 auto countInfo = getConstant(operands[2]);
1238 if (!countInfo) {
1239 return emitError(unknownLoc, "OpTypeArray count <id> ")
1240 << operands[2] << "can only come from normal constant right now";
1241 }
1242
1243 if (auto intVal = dyn_cast<IntegerAttr>(countInfo->first)) {
1244 count = intVal.getValue().getZExtValue();
1245 } else {
1246 return emitError(unknownLoc, "OpTypeArray count must come from a "
1247 "scalar integer constant instruction");
1248 }
1249
1250 typeMap[operands[0]] = spirv::ArrayType::get(
1251 elementTy, count, typeDecorations.lookup(operands[0]));
1252 return success();
1253}
1254
1255LogicalResult
1257 assert(!operands.empty() && "No operands for processing function type");
1258 if (operands.size() == 1) {
1259 return emitError(unknownLoc, "missing return type for OpTypeFunction");
1260 }
1261 auto returnType = getType(operands[1]);
1262 if (!returnType) {
1263 return emitError(unknownLoc, "unknown return type in OpTypeFunction");
1264 }
1265 SmallVector<Type, 1> argTypes;
1266 for (size_t i = 2, e = operands.size(); i < e; ++i) {
1267 auto ty = getType(operands[i]);
1268 if (!ty) {
1269 return emitError(unknownLoc, "unknown argument type in OpTypeFunction");
1270 }
1271 argTypes.push_back(ty);
1272 }
1273 ArrayRef<Type> returnTypes;
1274 if (!isVoidType(returnType)) {
1275 returnTypes = llvm::ArrayRef(returnType);
1276 }
1277 typeMap[operands[0]] = FunctionType::get(context, argTypes, returnTypes);
1278 return success();
1279}
1280
1282 ArrayRef<uint32_t> operands) {
1283 if (operands.size() != 6) {
1284 return emitError(unknownLoc,
1285 "OpTypeCooperativeMatrixKHR must have element type, "
1286 "scope, row and column parameters, and use");
1287 }
1288
1289 Type elementTy = getType(operands[1]);
1290 if (!elementTy) {
1291 return emitError(unknownLoc,
1292 "OpTypeCooperativeMatrixKHR references undefined <id> ")
1293 << operands[1];
1294 }
1295
1296 std::optional<spirv::Scope> scope =
1297 spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
1298 if (!scope) {
1299 return emitError(
1300 unknownLoc,
1301 "OpTypeCooperativeMatrixKHR references undefined scope <id> ")
1302 << operands[2];
1303 }
1304
1305 IntegerAttr rowsAttr = getConstantInt(operands[3]);
1306 IntegerAttr columnsAttr = getConstantInt(operands[4]);
1307 IntegerAttr useAttr = getConstantInt(operands[5]);
1308
1309 if (!rowsAttr)
1310 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Rows` references "
1311 "undefined constant <id> ")
1312 << operands[3];
1313
1314 if (!columnsAttr)
1315 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Columns` "
1316 "references undefined constant <id> ")
1317 << operands[4];
1318
1319 if (!useAttr)
1320 return emitError(unknownLoc, "OpTypeCooperativeMatrixKHR `Use` references "
1321 "undefined constant <id> ")
1322 << operands[5];
1323
1324 unsigned rows = rowsAttr.getInt();
1325 unsigned columns = columnsAttr.getInt();
1326
1327 std::optional<spirv::CooperativeMatrixUseKHR> use =
1328 spirv::symbolizeCooperativeMatrixUseKHR(useAttr.getInt());
1329 if (!use) {
1330 return emitError(
1331 unknownLoc,
1332 "OpTypeCooperativeMatrixKHR references undefined use <id> ")
1333 << operands[5];
1334 }
1335
1336 typeMap[operands[0]] =
1337 spirv::CooperativeMatrixType::get(elementTy, rows, columns, *scope, *use);
1338 return success();
1339}
1340
1341LogicalResult
1343 if (operands.size() != 2) {
1344 return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
1345 }
1346 Type memberType = getType(operands[1]);
1347 if (!memberType) {
1348 return emitError(unknownLoc,
1349 "OpTypeRuntimeArray references undefined <id> ")
1350 << operands[1];
1351 }
1352 typeMap[operands[0]] = spirv::RuntimeArrayType::get(
1353 memberType, typeDecorations.lookup(operands[0]));
1354 return success();
1355}
1356
/// Handles OpTypeStruct. operands[0] is the struct's result <id>; every
/// following word is the <id> of a member type.
///
/// A member may name a pointer <id> that so far was only declared through
/// OpTypeForwardPointer (tracked in typeForwardPointerIDs). Such a member is
/// kept as a null Type in memberTypes, and its (<id>, member index) pair is
/// recorded in unresolvedMemberTypes. For a named struct the body is then
/// deferred via deferredStructTypesInfos until the OpTypePointer handler
/// defines the missing pointer types.
1357LogicalResult
1359 // TODO: Find a way to handle identified structs when debug info is stripped.
1360
1361 if (operands.empty()) {
1362 return emitError(unknownLoc, "OpTypeStruct must have at least result <id>");
1363 }
1364
1365 if (operands.size() == 1) {
1366 // Handle empty struct.
1367 typeMap[operands[0]] =
1368 spirv::StructType::getEmpty(context, nameMap.lookup(operands[0]).str());
1369 return success();
1370 }
1371
1372 // First element is operand ID, second element is member index in the struct.
1373 SmallVector<std::pair<uint32_t, unsigned>, 0> unresolvedMemberTypes;
1374 SmallVector<Type, 4> memberTypes;
1375
1376 for (auto op : llvm::drop_begin(operands, 1)) {
1377 Type memberType = getType(op);
1378 bool typeForwardPtr = (typeForwardPointerIDs.count(op) != 0);
1379
1380 if (!memberType && !typeForwardPtr)
1381 return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
1382 << op;
1383
// A still-undefined forward pointer leaves a null placeholder in memberTypes;
// its slot index is remembered so the slot can be patched later.
1384 if (!memberType)
1385 unresolvedMemberTypes.emplace_back(op, memberTypes.size());
1386
1387 memberTypes.push_back(memberType);
1388 }
1389
// Translate per-member decorations. Offset values populate offsetInfo, which
// is sized to the member count on first use. Every other decoration becomes a
// member-decoration entry carrying its first operand as an i32 IntegerAttr,
// or a UnitAttr when the decoration has no operands.
1392 if (memberDecorationMap.count(operands[0])) {
1393 auto &allMemberDecorations = memberDecorationMap[operands[0]];
1394 for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
1395 if (allMemberDecorations.count(memberIndex)) {
1396 for (auto &memberDecoration : allMemberDecorations[memberIndex]) {
1397 // Check for offset.
1398 if (memberDecoration.first == spirv::Decoration::Offset) {
1399 // If offset info is empty, resize to the number of members;
1400 if (offsetInfo.empty()) {
1401 offsetInfo.resize(memberTypes.size());
1402 }
// NOTE(review): unlike the branch below, this does not check that the
// decoration has an operand before reading [0]. Presumably the decoration
// parser guarantees one for Offset — confirm.
1403 offsetInfo[memberIndex] = memberDecoration.second[0];
1404 } else {
1405 auto intType = mlir::IntegerType::get(context, 32);
1406 if (!memberDecoration.second.empty()) {
1407 memberDecorationsInfo.emplace_back(
1408 memberIndex, memberDecoration.first,
1409 IntegerAttr::get(intType, memberDecoration.second[0]));
1410 } else {
1411 memberDecorationsInfo.emplace_back(
1412 memberIndex, memberDecoration.first, UnitAttr::get(context));
1413 }
1414 }
1415 }
1416 }
1417 }
1418 }
1419
// Struct-level decorations are stored as snake_case attribute names; convert
// each name back to its spirv::Decoration enumerant. NOTE(review): an
// unrecognized name is only caught by the assert. Presumably every attribute
// in `decorations` came from a valid Decoration — confirm.
1421 if (decorations.count(operands[0])) {
1422 NamedAttrList &allDecorations = decorations[operands[0]];
1423 for (NamedAttribute &decorationAttr : allDecorations) {
1424 std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
1425 llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
1426 assert(decoration.has_value());
1427 structDecorationsInfo.emplace_back(decoration.value(),
1428 decorationAttr.getValue());
1429 }
1430 }
1431
1432 uint32_t structID = operands[0];
1433 std::string structIdentifier = nameMap.lookup(structID).str();
1434
// Structs without an OpName become literal (anonymous) struct types built in
// one shot, so they are asserted to have no unresolved members.
// NOTE(review): a binary with stripped names whose struct has a
// forward-pointer member trips this assert instead of producing an error (see
// the TODO at the top).
1435 if (structIdentifier.empty()) {
1436 assert(unresolvedMemberTypes.empty() &&
1437 "didn't expect unresolved member types");
1438 typeMap[structID] = spirv::StructType::get(
1439 memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
1440 } else {
// Identified structs are registered in typeMap before their body is set.
// The body is set now, or deferred until every pending member is resolved.
1441 auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
1442 typeMap[structID] = structTy;
1443
1444 if (!unresolvedMemberTypes.empty())
1445 deferredStructTypesInfos.push_back(
1446 {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
1447 memberDecorationsInfo, structDecorationsInfo});
1448 else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
1449 memberDecorationsInfo,
1450 structDecorationsInfo)))
1451 return failure();
1452 }
1453
1454 // TODO: Update StructType to have member name as attribute as
1455 // well.
1456 return success();
1457}
1458
1459LogicalResult
1461 if (operands.size() != 3) {
1462 // Three operands are needed: result_id, column_type, and column_count
1463 return emitError(unknownLoc, "OpTypeMatrix must have 3 operands"
1464 " (result_id, column_type, and column_count)");
1465 }
1466 // Matrix columns must be of vector type
1467 Type elementTy = getType(operands[1]);
1468 if (!elementTy) {
1469 return emitError(unknownLoc,
1470 "OpTypeMatrix references undefined column type.")
1471 << operands[1];
1472 }
1473
1474 uint32_t colsCount = operands[2];
1475 typeMap[operands[0]] = spirv::MatrixType::get(elementTy, colsCount);
1476 return success();
1477}
1478
1479LogicalResult
1481 unsigned size = operands.size();
1482 if (size < 2 || size > 4)
1483 return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
1484 "(result_id, element_type, (rank), (shape)) ")
1485 << size;
1486
1487 Type elementTy = getType(operands[1]);
1488 if (!elementTy)
1489 return emitError(unknownLoc,
1490 "OpTypeTensorARM references undefined element type ")
1491 << operands[1];
1492
1493 if (size == 2) {
1494 typeMap[operands[0]] = TensorArmType::get({}, elementTy);
1495 return success();
1496 }
1497
1498 IntegerAttr rankAttr = getConstantInt(operands[2]);
1499 if (!rankAttr)
1500 return emitError(unknownLoc, "OpTypeTensorARM rank must come from a "
1501 "scalar integer constant instruction");
1502 unsigned rank = rankAttr.getValue().getZExtValue();
1503 if (size == 3) {
1504 SmallVector<int64_t, 4> shape(rank, ShapedType::kDynamic);
1505 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1506 return success();
1507 }
1508
1509 std::optional<std::pair<Attribute, Type>> shapeInfo =
1510 getConstant(operands[3]);
1511 if (!shapeInfo)
1512 return emitError(unknownLoc, "OpTypeTensorARM shape must come from a "
1513 "constant instruction of type OpTypeArray");
1514
1515 ArrayAttr shapeArrayAttr = dyn_cast<ArrayAttr>(shapeInfo->first);
1517 for (auto dimAttr : shapeArrayAttr.getValue()) {
1518 auto dimIntAttr = dyn_cast<IntegerAttr>(dimAttr);
1519 if (!dimIntAttr)
1520 return emitError(unknownLoc, "OpTypeTensorARM shape has an invalid "
1521 "dimension size");
1522 shape.push_back(dimIntAttr.getValue().getSExtValue());
1523 }
1524 typeMap[operands[0]] = TensorArmType::get(shape, elementTy);
1525 return success();
1526}
1527
1528LogicalResult
1530 unsigned size = operands.size();
1531 if (size < 2) {
1532 return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
1533 "(result_id, num_inputs, (inout0_type, "
1534 "inout1_type, ...))")
1535 << size;
1536 }
1537 uint32_t numInputs = operands[1];
1538 SmallVector<Type, 1> argTypes;
1539 SmallVector<Type, 1> returnTypes;
1540 for (unsigned i = 2; i < size; ++i) {
1541 Type inOutTy = getType(operands[i]);
1542 if (!inOutTy) {
1543 return emitError(unknownLoc,
1544 "OpTypeGraphARM references undefined element type.")
1545 << operands[i];
1546 }
1547 if (i - 2 >= numInputs) {
1548 returnTypes.push_back(inOutTy);
1549 } else {
1550 argTypes.push_back(inOutTy);
1551 }
1552 }
1553 typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1554 return success();
1555}
1556
1557LogicalResult
1559 if (operands.size() != 2)
1560 return emitError(unknownLoc,
1561 "OpTypeForwardPointer instruction must have two operands");
1562
1563 typeForwardPointerIDs.insert(operands[0]);
1564 // TODO: Use the 2nd operand (Storage Class) to validate the OpTypePointer
1565 // instruction that defines the actual type.
1566
1567 return success();
1568}
1569
1570LogicalResult
1572 // TODO: Add support for Access Qualifier.
1573 if (operands.size() != 8)
1574 return emitError(
1575 unknownLoc,
1576 "OpTypeImage with non-eight operands are not supported yet");
1577
1578 Type elementTy = getType(operands[1]);
1579 if (!elementTy)
1580 return emitError(unknownLoc, "OpTypeImage references undefined <id>: ")
1581 << operands[1];
1582
1583 auto dim = spirv::symbolizeDim(operands[2]);
1584 if (!dim)
1585 return emitError(unknownLoc, "unknown Dim for OpTypeImage: ")
1586 << operands[2];
1587
1588 auto depthInfo = spirv::symbolizeImageDepthInfo(operands[3]);
1589 if (!depthInfo)
1590 return emitError(unknownLoc, "unknown Depth for OpTypeImage: ")
1591 << operands[3];
1592
1593 auto arrayedInfo = spirv::symbolizeImageArrayedInfo(operands[4]);
1594 if (!arrayedInfo)
1595 return emitError(unknownLoc, "unknown Arrayed for OpTypeImage: ")
1596 << operands[4];
1597
1598 auto samplingInfo = spirv::symbolizeImageSamplingInfo(operands[5]);
1599 if (!samplingInfo)
1600 return emitError(unknownLoc, "unknown MS for OpTypeImage: ") << operands[5];
1601
1602 auto samplerUseInfo = spirv::symbolizeImageSamplerUseInfo(operands[6]);
1603 if (!samplerUseInfo)
1604 return emitError(unknownLoc, "unknown Sampled for OpTypeImage: ")
1605 << operands[6];
1606
1607 auto format = spirv::symbolizeImageFormat(operands[7]);
1608 if (!format)
1609 return emitError(unknownLoc, "unknown Format for OpTypeImage: ")
1610 << operands[7];
1611
1612 typeMap[operands[0]] = spirv::ImageType::get(
1613 elementTy, dim.value(), depthInfo.value(), arrayedInfo.value(),
1614 samplingInfo.value(), samplerUseInfo.value(), format.value());
1615 return success();
1616}
1617
1618LogicalResult
1620 if (operands.size() != 2)
1621 return emitError(unknownLoc, "OpTypeSampledImage must have two operands");
1622
1623 Type elementTy = getType(operands[1]);
1624 if (!elementTy)
1625 return emitError(unknownLoc,
1626 "OpTypeSampledImage references undefined <id>: ")
1627 << operands[1];
1628
1629 typeMap[operands[0]] = spirv::SampledImageType::get(elementTy);
1630 return success();
1631}
1632
1633//===----------------------------------------------------------------------===//
1634// Constant
1635//===----------------------------------------------------------------------===//
1636
1638 bool isSpec) {
1639 StringRef opname = isSpec ? "OpSpecConstant" : "OpConstant";
1640
1641 if (operands.size() < 2) {
1642 return emitError(unknownLoc)
1643 << opname << " must have type <id> and result <id>";
1644 }
1645 if (operands.size() < 3) {
1646 return emitError(unknownLoc)
1647 << opname << " must have at least 1 more parameter";
1648 }
1649
1650 Type resultType = getType(operands[0]);
1651 if (!resultType) {
1652 return emitError(unknownLoc, "undefined result type from <id> ")
1653 << operands[0];
1654 }
1655
1656 auto checkOperandSizeForBitwidth = [&](unsigned bitwidth) -> LogicalResult {
1657 if (bitwidth == 64) {
1658 if (operands.size() == 4) {
1659 return success();
1660 }
1661 return emitError(unknownLoc)
1662 << opname << " should have 2 parameters for 64-bit values";
1663 }
1664 if (bitwidth <= 32) {
1665 if (operands.size() == 3) {
1666 return success();
1667 }
1668
1669 return emitError(unknownLoc)
1670 << opname
1671 << " should have 1 parameter for values with no more than 32 bits";
1672 }
1673 return emitError(unknownLoc, "unsupported OpConstant bitwidth: ")
1674 << bitwidth;
1675 };
1676
1677 auto resultID = operands[1];
1678
1679 if (auto intType = dyn_cast<IntegerType>(resultType)) {
1680 auto bitwidth = intType.getWidth();
1681 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1682 return failure();
1683 }
1684
1685 APInt value;
1686 if (bitwidth == 64) {
1687 // 64-bit integers are represented with two SPIR-V words. According to
1688 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1689 // literal’s low-order words appear first."
1690 struct DoubleWord {
1691 uint32_t word1;
1692 uint32_t word2;
1693 } words = {operands[2], operands[3]};
1694 value = APInt(64, llvm::bit_cast<uint64_t>(words), /*isSigned=*/true);
1695 } else if (bitwidth <= 32) {
1696 value = APInt(bitwidth, operands[2], /*isSigned=*/true,
1697 /*implicitTrunc=*/true);
1698 }
1699
1700 auto attr = opBuilder.getIntegerAttr(intType, value);
1701
1702 if (isSpec) {
1703 createSpecConstant(unknownLoc, resultID, attr);
1704 } else {
1705 // For normal constants, we just record the attribute (and its type) for
1706 // later materialization at use sites.
1707 constantMap.try_emplace(resultID, attr, intType);
1708 }
1709
1710 return success();
1711 }
1712
1713 if (auto floatType = dyn_cast<FloatType>(resultType)) {
1714 auto bitwidth = floatType.getWidth();
1715 if (failed(checkOperandSizeForBitwidth(bitwidth))) {
1716 return failure();
1717 }
1718
1719 APFloat value(0.f);
1720 if (floatType.isF64()) {
1721 // Double values are represented with two SPIR-V words. According to
1722 // SPIR-V spec: "When the type’s bit width is larger than one word, the
1723 // literal’s low-order words appear first."
1724 struct DoubleWord {
1725 uint32_t word1;
1726 uint32_t word2;
1727 } words = {operands[2], operands[3]};
1728 value = APFloat(llvm::bit_cast<double>(words));
1729 } else if (floatType.isF32()) {
1730 value = APFloat(llvm::bit_cast<float>(operands[2]));
1731 } else if (floatType.isF16()) {
1732 APInt data(16, operands[2]);
1733 value = APFloat(APFloat::IEEEhalf(), data);
1734 } else if (floatType.isBF16()) {
1735 APInt data(16, operands[2]);
1736 value = APFloat(APFloat::BFloat(), data);
1737 }
1738
1739 auto attr = opBuilder.getFloatAttr(floatType, value);
1740 if (isSpec) {
1741 createSpecConstant(unknownLoc, resultID, attr);
1742 } else {
1743 // For normal constants, we just record the attribute (and its type) for
1744 // later materialization at use sites.
1745 constantMap.try_emplace(resultID, attr, floatType);
1746 }
1747
1748 return success();
1749 }
1750
1751 return emitError(unknownLoc, "OpConstant can only generate values of "
1752 "scalar integer or floating-point type");
1753}
1754
1756 bool isTrue, ArrayRef<uint32_t> operands, bool isSpec) {
1757 if (operands.size() != 2) {
1758 return emitError(unknownLoc, "Op")
1759 << (isSpec ? "Spec" : "") << "Constant"
1760 << (isTrue ? "True" : "False")
1761 << " must have type <id> and result <id>";
1762 }
1763
1764 auto attr = opBuilder.getBoolAttr(isTrue);
1765 auto resultID = operands[1];
1766 if (isSpec) {
1767 createSpecConstant(unknownLoc, resultID, attr);
1768 } else {
1769 // For normal constants, we just record the attribute (and its type) for
1770 // later materialization at use sites.
1771 constantMap.try_emplace(resultID, attr, opBuilder.getI1Type());
1772 }
1773
1774 return success();
1775}
1776
/// Handles OpConstantComposite: operands are the result type <id>, the
/// result <id>, and one or more constituent <id>s that must all be normal
/// constants.
///
/// Nothing is materialized here. The composite's attribute and type are
/// recorded in constantMap:
///   - TensorArmType: a DenseElementsAttr over the constituents, flattening
///     constituents that are themselves DenseElementsAttrs;
///   - other ShapedTypes: a DenseElementsAttr over the constituents;
///   - spirv::ArrayType: an ArrayAttr of the constituents.
/// NOTE(review): the constituent count is not validated against the shape
/// here; presumably DenseElementsAttr::get's own checks are relied on —
/// confirm.
1777LogicalResult
1779 if (operands.size() < 2) {
1780 return emitError(unknownLoc,
1781 "OpConstantComposite must have type <id> and result <id>");
1782 }
1783 if (operands.size() < 3) {
1784 return emitError(unknownLoc,
1785 "OpConstantComposite must have at least 1 parameter");
1786 }
1787
1788 Type resultType = getType(operands[0]);
1789 if (!resultType) {
1790 return emitError(unknownLoc, "undefined result type from <id> ")
1791 << operands[0];
1792 }
1793
// Collect each constituent's attribute; every constituent must be a normal
// constant.
1795 elements.reserve(operands.size() - 2);
1796 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1797 auto elementInfo = getConstant(operands[i]);
1798 if (!elementInfo) {
1799 return emitError(unknownLoc, "OpConstantComposite component <id> ")
1800 << operands[i] << " must come from a normal constant";
1801 }
1802 elements.push_back(elementInfo->first);
1803 }
1804
// Choose the attribute form based on the result type.
1805 auto resultID = operands[1];
1806 if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
1807 SmallVector<Attribute> flattenedElems;
1808 for (Attribute element : elements) {
1809 if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
1810 for (auto value : denseElemAttr.getValues<Attribute>())
1811 flattenedElems.push_back(value);
1812 } else {
1813 flattenedElems.push_back(element);
1814 }
1815 }
1816 auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
1817 constantMap.try_emplace(resultID, attr, tensorType);
1818 } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1819 auto attr = DenseElementsAttr::get(shapedType, elements);
1820 // For normal constants, we just record the attribute (and its type) for
1821 // later materialization at use sites.
1822 constantMap.try_emplace(resultID, attr, shapedType);
1823 } else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
1824 auto attr = opBuilder.getArrayAttr(elements);
1825 constantMap.try_emplace(resultID, attr, resultType);
1826 } else {
1827 return emitError(unknownLoc, "unsupported OpConstantComposite type: ")
1828 << resultType;
1829 }
1830
1831 return success();
1832}
1833
1835 ArrayRef<uint32_t> operands) {
// Handles OpConstantCompositeReplicateEXT: operands are the result type <id>,
// the result <id>, and the <id> of the value to replicate. The result type
// must be a CompositeType. The replicated value may be a normal constant or
// an earlier OpConstantCompositeReplicateEXT. In either case only its
// attribute is recorded, with the result type, in
// constantCompositeReplicateMap.
1836 if (operands.size() != 3) {
1837 return emitError(
1838 unknownLoc,
1839 "OpConstantCompositeReplicateEXT expects 3 operands but found ")
1840 << operands.size();
1841 }
1842
1843 Type resultType = getType(operands[0]);
1844 if (!resultType) {
1845 return emitError(unknownLoc, "undefined result type from <id> ")
1846 << operands[0];
1847 }
1848
1849 auto compositeType = dyn_cast<CompositeType>(resultType);
1850 if (!compositeType) {
// NOTE(review): the <id> is streamed directly after "composite type" with no
// separator, so the message reads "...composite type5".
1851 return emitError(unknownLoc,
1852 "result type from <id> is not a composite type")
1853 << operands[0];
1854 }
1855
1856 uint32_t resultID = operands[1];
1857 uint32_t constantID = operands[2];
1858
// First try a normal constant as the replicated value.
1859 std::optional<std::pair<Attribute, Type>> constantInfo =
1860 getConstant(constantID);
1861 if (constantInfo.has_value()) {
1862 constantCompositeReplicateMap.try_emplace(
1863 resultID, constantInfo.value().first, resultType);
1864 return success();
1865 }
1866
// Otherwise fall back to a previously recorded replicated composite.
1867 std::optional<std::pair<Attribute, Type>> replicatedConstantCompositeInfo =
1869 if (replicatedConstantCompositeInfo.has_value()) {
1870 constantCompositeReplicateMap.try_emplace(
1871 resultID, replicatedConstantCompositeInfo.value().first, resultType);
1872 return success();
1873 }
1874
1875 return emitError(unknownLoc, "OpConstantCompositeReplicateEXT operand <id> ")
1876 << constantID
1877 << " must come from a normal constant or a "
1878 "OpConstantCompositeReplicateEXT";
1879}
1880
1881LogicalResult
1883 if (operands.size() < 2) {
1884 return emitError(
1885 unknownLoc,
1886 "OpSpecConstantComposite must have type <id> and result <id>");
1887 }
1888 if (operands.size() < 3) {
1889 return emitError(unknownLoc,
1890 "OpSpecConstantComposite must have at least 1 parameter");
1891 }
1892
1893 Type resultType = getType(operands[0]);
1894 if (!resultType) {
1895 return emitError(unknownLoc, "undefined result type from <id> ")
1896 << operands[0];
1897 }
1898
1899 auto resultID = operands[1];
1900 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1901
1903 elements.reserve(operands.size() - 2);
1904 for (unsigned i = 2, e = operands.size(); i < e; ++i) {
1905 auto elementInfo = getSpecConstant(operands[i]);
1906 elements.push_back(SymbolRefAttr::get(elementInfo));
1907 }
1908
1909 auto op = spirv::SpecConstantCompositeOp::create(
1910 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1911 opBuilder.getArrayAttr(elements));
1912 specConstCompositeMap[resultID] = op;
1913
1914 return success();
1915}
1916
1918 ArrayRef<uint32_t> operands) {
1919 if (operands.size() != 3) {
1920 return emitError(unknownLoc, "OpSpecConstantCompositeReplicateEXT expects "
1921 "3 operands but found ")
1922 << operands.size();
1923 }
1924
1925 Type resultType = getType(operands[0]);
1926 if (!resultType) {
1927 return emitError(unknownLoc, "undefined result type from <id> ")
1928 << operands[0];
1929 }
1930
1931 auto compositeType = dyn_cast<CompositeType>(resultType);
1932 if (!compositeType) {
1933 return emitError(unknownLoc,
1934 "result type from <id> is not a composite type")
1935 << operands[0];
1936 }
1937
1938 uint32_t resultID = operands[1];
1939
1940 auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
1941 spirv::SpecConstantOp constituentSpecConstantOp =
1942 getSpecConstant(operands[2]);
1943 auto op = spirv::EXTSpecConstantCompositeReplicateOp::create(
1944 opBuilder, unknownLoc, TypeAttr::get(resultType), symName,
1945 SymbolRefAttr::get(constituentSpecConstantOp));
1946
1947 specConstCompositeReplicateMap[resultID] = op;
1948
1949 return success();
1950}
1951
1952LogicalResult
1954 if (operands.size() < 3)
1955 return emitError(unknownLoc, "OpConstantOperation must have type <id>, "
1956 "result <id>, and operand opcode");
1957
1958 uint32_t resultTypeID = operands[0];
1959
1960 if (!getType(resultTypeID))
1961 return emitError(unknownLoc, "undefined result type from <id> ")
1962 << resultTypeID;
1963
1964 uint32_t resultID = operands[1];
1965 spirv::Opcode enclosedOpcode = static_cast<spirv::Opcode>(operands[2]);
1966 auto emplaceResult = specConstOperationMap.try_emplace(
1967 resultID,
1969 enclosedOpcode, resultTypeID,
1970 SmallVector<uint32_t>{operands.begin() + 3, operands.end()}});
1971
1972 if (!emplaceResult.second)
1973 return emitError(unknownLoc, "value with <id>: ")
1974 << resultID << " is probably defined before.";
1975
1976 return success();
1977}
1978
1980 uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID,
1981 ArrayRef<uint32_t> enclosedOpOperands) {
1982
1983 Type resultType = getType(resultTypeID);
1984
1985 // Instructions wrapped by OpSpecConstantOp need an ID for their
1986 // Deserializer::processOp<op_name>(...) to emit the corresponding SPIR-V
1987 // dialect wrapped op. For that purpose, a new value map is created and "fake"
1988 // ID in that map is assigned to the result of the enclosed instruction. Note
1989 // that there is no need to update this fake ID since we only need to
1990 // reference the created Value for the enclosed op from the spv::YieldOp
1991 // created later in this method (both of which are the only values in their
1992 // region: the SpecConstantOperation's region). If we encounter another
1993 // SpecConstantOperation in the module, we simply re-use the fake ID since the
1994 // previous Value assigned to it isn't visible in the current scope anyway.
1995 DenseMap<uint32_t, Value> newValueMap;
1996 llvm::SaveAndRestore valueMapGuard(valueMap, newValueMap);
1997 constexpr uint32_t fakeID = static_cast<uint32_t>(-3);
1998
1999 SmallVector<uint32_t, 4> enclosedOpResultTypeAndOperands;
2000 enclosedOpResultTypeAndOperands.push_back(resultTypeID);
2001 enclosedOpResultTypeAndOperands.push_back(fakeID);
2002 enclosedOpResultTypeAndOperands.append(enclosedOpOperands.begin(),
2003 enclosedOpOperands.end());
2004
2005 // Process enclosed instruction before creating the enclosing
2006 // specConstantOperation (and its region). This way, references to constants,
2007 // global variables, and spec constants will be materialized outside the new
2008 // op's region. For more info, see Deserializer::getValue's implementation.
2009 if (failed(
2010 processInstruction(enclosedOpcode, enclosedOpResultTypeAndOperands)))
2011 return Value();
2012
2013 // Since the enclosed op is emitted in the current block, split it in a
2014 // separate new block.
2015 Block *enclosedBlock = curBlock->splitBlock(&curBlock->back());
2016
2017 auto loc = createFileLineColLoc(opBuilder);
2018 auto specConstOperationOp =
2019 spirv::SpecConstantOperationOp::create(opBuilder, loc, resultType);
2020
2021 Region &body = specConstOperationOp.getBody();
2022 // Move the new block into SpecConstantOperation's body.
2023 body.getBlocks().splice(body.end(), curBlock->getParent()->getBlocks(),
2024 Region::iterator(enclosedBlock));
2025 Block &block = body.back();
2026
2027 // RAII guard to reset the insertion point to the module's region after
2028 // deserializing the body of the specConstantOperation.
2029 OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
2030 opBuilder.setInsertionPointToEnd(&block);
2031
2032 spirv::YieldOp::create(opBuilder, loc, block.front().getResult(0));
2033 return specConstOperationOp.getResult();
2034}
2035
2036LogicalResult
2038 if (operands.size() != 2) {
2039 return emitError(unknownLoc,
2040 "OpConstantNull must only have type <id> and result <id>");
2041 }
2042
2043 Type resultType = getType(operands[0]);
2044 if (!resultType) {
2045 return emitError(unknownLoc, "undefined result type from <id> ")
2046 << operands[0];
2047 }
2048
2049 auto resultID = operands[1];
2050 Attribute attr;
2051 if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
2052 attr = opBuilder.getZeroAttr(resultType);
2053 } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
2054 if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
2055 attr = DenseElementsAttr::get(tensorType, element);
2056 }
2057
2058 if (attr) {
2059 // For normal constants, we just record the attribute (and its type) for
2060 // later materialization at use sites.
2061 constantMap.try_emplace(resultID, attr, resultType);
2062 return success();
2063 }
2064
2065 return emitError(unknownLoc, "unsupported OpConstantNull type: ")
2066 << resultType;
2067}
2068
2069LogicalResult
2071 if (operands.size() < 3) {
2072 return emitError(unknownLoc)
2073 << "OpGraphConstantARM must have at least 2 operands";
2074 }
2075
2076 Type resultType = getType(operands[0]);
2077 if (!resultType) {
2078 return emitError(unknownLoc, "undefined result type from <id> ")
2079 << operands[0];
2080 }
2081
2082 uint32_t resultID = operands[1];
2083
2084 if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2085 return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
2086 }
2087
2088 APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
2089 Type i32Ty = opBuilder.getIntegerType(32);
2090 IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2091 graphConstantMap.try_emplace(
2092 resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2093
2094 return success();
2095}
2096
2097//===----------------------------------------------------------------------===//
2098// Control flow
2099//===----------------------------------------------------------------------===//
2100
2102 if (auto *block = getBlock(id)) {
2103 LLVM_DEBUG(logger.startLine() << "[block] got exiting block for id = " << id
2104 << " @ " << block << "\n");
2105 return block;
2106 }
2107
2108 // We don't know where this block will be placed finally (in a
2109 // spirv.mlir.selection or spirv.mlir.loop or function). Create it into the
2110 // function for now and sort out the proper place later.
2111 auto *block = curFunction->addBlock();
2112 LLVM_DEBUG(logger.startLine() << "[block] created block for id = " << id
2113 << " @ " << block << "\n");
2114 return blockMap[id] = block;
2115}
2116
2118 if (!curBlock) {
2119 return emitError(unknownLoc, "OpBranch must appear inside a block");
2120 }
2121
2122 if (operands.size() != 1) {
2123 return emitError(unknownLoc, "OpBranch must take exactly one target label");
2124 }
2125
2126 auto *target = getOrCreateBlock(operands[0]);
2127 auto loc = createFileLineColLoc(opBuilder);
2128 // The preceding instruction for the OpBranch instruction could be an
2129 // OpLoopMerge or an OpSelectionMerge instruction, in this case they will have
2130 // the same OpLine information.
2131 spirv::BranchOp::create(opBuilder, loc, target);
2132
2134 return success();
2135}
2136
2137LogicalResult
2139 if (!curBlock) {
2140 return emitError(unknownLoc,
2141 "OpBranchConditional must appear inside a block");
2142 }
2143
2144 if (operands.size() != 3 && operands.size() != 5) {
2145 return emitError(unknownLoc,
2146 "OpBranchConditional must have condition, true label, "
2147 "false label, and optionally two branch weights");
2148 }
2149
2150 auto condition = getValue(operands[0]);
2151 auto *trueBlock = getOrCreateBlock(operands[1]);
2152 auto *falseBlock = getOrCreateBlock(operands[2]);
2153
2154 std::optional<std::pair<uint32_t, uint32_t>> weights;
2155 if (operands.size() == 5) {
2156 weights = std::make_pair(operands[3], operands[4]);
2157 }
2158 // The preceding instruction for the OpBranchConditional instruction could be
2159 // an OpSelectionMerge instruction, in this case they will have the same
2160 // OpLine information.
2161 auto loc = createFileLineColLoc(opBuilder);
2162 spirv::BranchConditionalOp::create(
2163 opBuilder, loc, condition, trueBlock,
2164 /*trueArguments=*/ArrayRef<Value>(), falseBlock,
2165 /*falseArguments=*/ArrayRef<Value>(), weights);
2166
2168 return success();
2169}
2170
2172 if (!curFunction) {
2173 return emitError(unknownLoc, "OpLabel must appear inside a function");
2174 }
2175
2176 if (operands.size() != 1) {
2177 return emitError(unknownLoc, "OpLabel should only have result <id>");
2178 }
2179
2180 auto labelID = operands[0];
2181 // We may have forward declared this block.
2182 auto *block = getOrCreateBlock(labelID);
2183 LLVM_DEBUG(logger.startLine()
2184 << "[block] populating block " << block << "\n");
2185 // If we have seen this block, make sure it was just a forward declaration.
2186 assert(block->empty() && "re-deserialize the same block!");
2187
2188 opBuilder.setInsertionPointToStart(block);
2189 blockMap[labelID] = curBlock = block;
2190
2191 return success();
2192}
2193
2194LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
2195 if (!curGraph) {
2196 return emitError(unknownLoc, "a graph block must appear inside a graph");
2197 }
2198
2199 // We may have forward declared this block.
2200 Block *block = getOrCreateBlock(graphID);
2201 LLVM_DEBUG(logger.startLine()
2202 << "[block] populating block " << block << "\n");
2203 // If we have seen this block, make sure it was just a forward declaration.
2204 assert(block->empty() && "re-deserialize the same block!");
2205
2206 opBuilder.setInsertionPointToStart(block);
2207 blockMap[graphID] = curBlock = block;
2208
2209 return success();
2210}
2211
2212LogicalResult
2214 if (!curBlock) {
2215 return emitError(unknownLoc, "OpSelectionMerge must appear in a block");
2216 }
2217
2218 if (operands.size() < 2) {
2219 return emitError(
2220 unknownLoc,
2221 "OpSelectionMerge must specify merge target and selection control");
2222 }
2223
2224 auto *mergeBlock = getOrCreateBlock(operands[0]);
2225 auto loc = createFileLineColLoc(opBuilder);
2226 auto selectionControl = operands[1];
2227
2228 if (!blockMergeInfo.try_emplace(curBlock, loc, selectionControl, mergeBlock)
2229 .second) {
2230 return emitError(
2231 unknownLoc,
2232 "a block cannot have more than one OpSelectionMerge instruction");
2233 }
2234
2235 return success();
2236}
2237
2238LogicalResult
2240 if (!curBlock) {
2241 return emitError(unknownLoc, "OpLoopMerge must appear in a block");
2242 }
2243
2244 if (operands.size() < 3) {
2245 return emitError(unknownLoc, "OpLoopMerge must specify merge target, "
2246 "continue target and loop control");
2247 }
2248
2249 auto *mergeBlock = getOrCreateBlock(operands[0]);
2250 auto *continueBlock = getOrCreateBlock(operands[1]);
2251 auto loc = createFileLineColLoc(opBuilder);
2252 uint32_t loopControl = operands[2];
2253
2254 if (!blockMergeInfo
2255 .try_emplace(curBlock, loc, loopControl, mergeBlock, continueBlock)
2256 .second) {
2257 return emitError(
2258 unknownLoc,
2259 "a block cannot have more than one OpLoopMerge instruction");
2260 }
2261
2262 return success();
2263}
2264
2266 if (!curBlock) {
2267 return emitError(unknownLoc, "OpPhi must appear in a block");
2268 }
2269
2270 if (operands.size() < 4) {
2271 return emitError(unknownLoc, "OpPhi must specify result type, result <id>, "
2272 "and variable-parent pairs");
2273 }
2274
2275 // Create a block argument for this OpPhi instruction.
2276 Type blockArgType = getType(operands[0]);
2277 BlockArgument blockArg = curBlock->addArgument(blockArgType, unknownLoc);
2278 valueMap[operands[1]] = blockArg;
2279 LLVM_DEBUG(logger.startLine()
2280 << "[phi] created block argument " << blockArg
2281 << " id = " << operands[1] << " of type " << blockArgType << "\n");
2282
2283 // For each (value, predecessor) pair, insert the value to the predecessor's
2284 // blockPhiInfo entry so later we can fix the block argument there.
2285 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2286 uint32_t value = operands[i];
2287 Block *predecessor = getOrCreateBlock(operands[i + 1]);
2288 std::pair<Block *, Block *> predecessorTargetPair{predecessor, curBlock};
2289 blockPhiInfo[predecessorTargetPair].push_back(value);
2290 LLVM_DEBUG(logger.startLine() << "[phi] predecessor @ " << predecessor
2291 << " with arg id = " << value << "\n");
2292 }
2293
2294 return success();
2295}
2296
2298 if (!curBlock)
2299 return emitError(unknownLoc, "OpSwitch must appear in a block");
2300
2301 if (operands.size() < 2)
2302 return emitError(unknownLoc, "OpSwitch must at least specify selector and "
2303 "a default target");
2304
2305 if (operands.size() % 2)
2306 return emitError(unknownLoc,
2307 "OpSwitch must at have an even number of operands: "
2308 "selector, default target and any number of literal and "
2309 "label <id> pairs");
2310
2311 Value selector = getValue(operands[0]);
2312 Block *defaultBlock = getOrCreateBlock(operands[1]);
2313 Location loc = createFileLineColLoc(opBuilder);
2314
2315 SmallVector<int32_t> literals;
2316 SmallVector<Block *> blocks;
2317 for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2318 literals.push_back(operands[i]);
2319 blocks.push_back(getOrCreateBlock(operands[i + 1]));
2320 }
2321
2322 SmallVector<ValueRange> targetOperands(blocks.size(), {});
2323 spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2324 ArrayRef<Value>(), literals, blocks, targetOperands);
2325
2326 return success();
2327}
2328
2329namespace {
2330/// A class for putting all blocks in a structured selection/loop in a
2331/// spirv.mlir.selection/spirv.mlir.loop op.
2332class ControlFlowStructurizer {
2333public:
2334#ifndef NDEBUG
2335 ControlFlowStructurizer(Location loc, uint32_t control,
2336 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2337 Block *merge, Block *cont,
2338 llvm::ScopedPrinter &logger)
2339 : location(loc), control(control), blockMergeInfo(mergeInfo),
2340 headerBlock(header), mergeBlock(merge), continueBlock(cont),
2341 logger(logger) {}
2342#else
2343 ControlFlowStructurizer(Location loc, uint32_t control,
2344 spirv::BlockMergeInfoMap &mergeInfo, Block *header,
2345 Block *merge, Block *cont)
2346 : location(loc), control(control), blockMergeInfo(mergeInfo),
2347 headerBlock(header), mergeBlock(merge), continueBlock(cont) {}
2348#endif
2349
2350 /// Structurizes the loop at the given `headerBlock`.
2351 ///
2352 /// This method will create an spirv.mlir.loop op in the `mergeBlock` and move
2353 /// all blocks in the structured loop into the spirv.mlir.loop's region. All
2354 /// branches to the `headerBlock` will be redirected to the `mergeBlock`. This
2355 /// method will also update `mergeInfo` by remapping all blocks inside to the
2356 /// newly cloned ones inside structured control flow op's regions.
2357 LogicalResult structurize();
2358
2359private:
2360 /// Creates a new spirv.mlir.selection op at the beginning of the
2361 /// `mergeBlock`.
2362 spirv::SelectionOp createSelectionOp(uint32_t selectionControl);
2363
2364 /// Creates a new spirv.mlir.loop op at the beginning of the `mergeBlock`.
2365 spirv::LoopOp createLoopOp(uint32_t loopControl);
2366
2367 /// Collects all blocks reachable from `headerBlock` except `mergeBlock`.
2368 void collectBlocksInConstruct();
2369
2370 Location location;
2371 uint32_t control;
2372
2373 spirv::BlockMergeInfoMap &blockMergeInfo;
2374
2375 Block *headerBlock;
2376 Block *mergeBlock;
2377 Block *continueBlock; // nullptr for spirv.mlir.selection
2378
2379 SetVector<Block *> constructBlocks;
2380
2381#ifndef NDEBUG
2382 /// A logger used to emit information during the deserialzation process.
2383 llvm::ScopedPrinter &logger;
2384#endif
2385};
2386} // namespace
2387
2388spirv::SelectionOp
2389ControlFlowStructurizer::createSelectionOp(uint32_t selectionControl) {
2390 // Create a builder and set the insertion point to the beginning of the
2391 // merge block so that the newly created SelectionOp will be inserted there.
2392 OpBuilder builder(&mergeBlock->front());
2393
2394 auto control = static_cast<spirv::SelectionControl>(selectionControl);
2395 auto selectionOp = spirv::SelectionOp::create(builder, location, control);
2396 selectionOp.addMergeBlock(builder);
2397
2398 return selectionOp;
2399}
2400
2401spirv::LoopOp ControlFlowStructurizer::createLoopOp(uint32_t loopControl) {
2402 // Create a builder and set the insertion point to the beginning of the
2403 // merge block so that the newly created LoopOp will be inserted there.
2404 OpBuilder builder(&mergeBlock->front());
2405
2406 auto control = static_cast<spirv::LoopControl>(loopControl);
2407 auto loopOp = spirv::LoopOp::create(builder, location, control);
2408 loopOp.addEntryAndMergeBlock(builder);
2409
2410 return loopOp;
2411}
2412
2413void ControlFlowStructurizer::collectBlocksInConstruct() {
2414 assert(constructBlocks.empty() && "expected empty constructBlocks");
2415
2416 // Put the header block in the work list first.
2417 constructBlocks.insert(headerBlock);
2418
2419 // For each item in the work list, add its successors excluding the merge
2420 // block.
2421 for (unsigned i = 0; i < constructBlocks.size(); ++i) {
2422 for (auto *successor : constructBlocks[i]->getSuccessors())
2423 if (successor != mergeBlock)
2424 constructBlocks.insert(successor);
2425 }
2426}
2427
2428LogicalResult ControlFlowStructurizer::structurize() {
2429 Operation *op = nullptr;
2430 bool isLoop = continueBlock != nullptr;
2431 if (isLoop) {
2432 if (auto loopOp = createLoopOp(control))
2433 op = loopOp.getOperation();
2434 } else {
2435 if (auto selectionOp = createSelectionOp(control))
2436 op = selectionOp.getOperation();
2437 }
2438 if (!op)
2439 return failure();
2440 Region &body = op->getRegion(0);
2441
2442 IRMapping mapper;
2443 // All references to the old merge block should be directed to the
2444 // selection/loop merge block in the SelectionOp/LoopOp's region.
2445 mapper.map(mergeBlock, &body.back());
2446
2447 collectBlocksInConstruct();
2448
2449 // We've identified all blocks belonging to the selection/loop's region. Now
2450 // need to "move" them into the selection/loop. Instead of really moving the
2451 // blocks, in the following we copy them and remap all values and branches.
2452 // This is because:
2453 // * Inserting a block into a region requires the block not in any region
2454 // before. But selections/loops can nest so we can create selection/loop ops
2455 // in a nested manner, which means some blocks may already be in a
2456 // selection/loop region when to be moved again.
2457 // * It's much trickier to fix up the branches into and out of the loop's
2458 // region: we need to treat not-moved blocks and moved blocks differently:
2459 // Not-moved blocks jumping to the loop header block need to jump to the
2460 // merge point containing the new loop op but not the loop continue block's
2461 // back edge. Moved blocks jumping out of the loop need to jump to the
2462 // merge block inside the loop region but not other not-moved blocks.
2463 // We cannot use replaceAllUsesWith clearly and it's harder to follow the
2464 // logic.
2465
2466 // Create a corresponding block in the SelectionOp/LoopOp's region for each
2467 // block in this loop construct.
2468 OpBuilder builder(body);
2469 for (auto *block : constructBlocks) {
2470 // Create a block and insert it before the selection/loop merge block in the
2471 // SelectionOp/LoopOp's region.
2472 auto *newBlock = builder.createBlock(&body.back());
2473 mapper.map(block, newBlock);
2474 LLVM_DEBUG(logger.startLine() << "[cf] cloned block " << newBlock
2475 << " from block " << block << "\n");
2476 if (!isFnEntryBlock(block)) {
2477 for (BlockArgument blockArg : block->getArguments()) {
2478 auto newArg =
2479 newBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2480 mapper.map(blockArg, newArg);
2481 LLVM_DEBUG(logger.startLine() << "[cf] remapped block argument "
2482 << blockArg << " to " << newArg << "\n");
2483 }
2484 } else {
2485 LLVM_DEBUG(logger.startLine()
2486 << "[cf] block " << block << " is a function entry block\n");
2487 }
2488
2489 for (auto &op : *block)
2490 newBlock->push_back(op.clone(mapper));
2491 }
2492
2493 // Go through all ops and remap the operands.
2494 auto remapOperands = [&](Operation *op) {
2495 for (auto &operand : op->getOpOperands())
2496 if (Value mappedOp = mapper.lookupOrNull(operand.get()))
2497 operand.set(mappedOp);
2498 for (auto &succOp : op->getBlockOperands())
2499 if (Block *mappedOp = mapper.lookupOrNull(succOp.get()))
2500 succOp.set(mappedOp);
2501 };
2502 for (auto &block : body)
2503 block.walk(remapOperands);
2504
2505 // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to
2506 // the selection/loop construct into its region. Next we need to fix the
2507 // connections between this new SelectionOp/LoopOp with existing blocks.
2508
2509 // All existing incoming branches should go to the merge block, where the
2510 // SelectionOp/LoopOp resides right now.
2511 headerBlock->replaceAllUsesWith(mergeBlock);
2512
2513 LLVM_DEBUG({
2514 logger.startLine() << "[cf] after cloning and fixing references:\n";
2515 headerBlock->getParentOp()->print(logger.getOStream());
2516 logger.startLine() << "\n";
2517 });
2518
2519 if (isLoop) {
2520 if (!mergeBlock->args_empty()) {
2521 return mergeBlock->getParentOp()->emitError(
2522 "OpPhi in loop merge block unsupported");
2523 }
2524
2525 // The loop header block may have block arguments. Since now we place the
2526 // loop op inside the old merge block, we need to make sure the old merge
2527 // block has the same block argument list.
2528 for (BlockArgument blockArg : headerBlock->getArguments())
2529 mergeBlock->addArgument(blockArg.getType(), blockArg.getLoc());
2530
2531 // If the loop header block has block arguments, make sure the spirv.Branch
2532 // op matches.
2533 SmallVector<Value, 4> blockArgs;
2534 if (!headerBlock->args_empty())
2535 blockArgs = {mergeBlock->args_begin(), mergeBlock->args_end()};
2536
2537 // The loop entry block should have a unconditional branch jumping to the
2538 // loop header block.
2539 builder.setInsertionPointToEnd(&body.front());
2540 spirv::BranchOp::create(builder, location, mapper.lookupOrNull(headerBlock),
2541 ArrayRef<Value>(blockArgs));
2542 }
2543
2544 // Values defined inside the selection region that need to be yielded outside
2545 // the region.
2546 SmallVector<Value> valuesToYield;
2547 // Outside uses of values that were sunk into the selection region. Those uses
2548 // will be replaced with values returned by the SelectionOp.
2549 SmallVector<Value> outsideUses;
2550
2551 // Move block arguments of the original block (`mergeBlock`) into the merge
2552 // block inside the selection (`body.back()`). Values produced by block
2553 // arguments will be yielded by the selection region. We do not update uses or
2554 // erase original block arguments yet. It will be done later in the code.
2555 //
2556 // Code below is not executed for loops as it would interfere with the logic
2557 // above. Currently block arguments in the merge block are not supported, but
2558 // instead, the code above copies those arguments from the header block into
2559 // the merge block. As such, running the code would yield those copied
2560 // arguments that is most likely not a desired behaviour. This may need to be
2561 // revisited in the future.
2562 if (!isLoop)
2563 for (BlockArgument blockArg : mergeBlock->getArguments()) {
2564 // Create new block arguments in the last block ("merge block") of the
2565 // selection region. We create one argument for each argument in
2566 // `mergeBlock`. This new value will need to be yielded, and the original
2567 // value replaced, so add them to appropriate vectors.
2568 body.back().addArgument(blockArg.getType(), blockArg.getLoc());
2569 valuesToYield.push_back(body.back().getArguments().back());
2570 outsideUses.push_back(blockArg);
2571 }
2572
2573 // All the blocks cloned into the SelectionOp/LoopOp's region can now be
2574 // cleaned up.
2575 LLVM_DEBUG(logger.startLine() << "[cf] cleaning up blocks after clone\n");
2576 // First we need to drop all operands' references inside all blocks. This is
2577 // needed because we can have blocks referencing SSA values from one another.
2578 for (auto *block : constructBlocks)
2579 block->dropAllReferences();
2580
2581 // All internal uses should be removed from original blocks by now, so
2582 // whatever is left is an outside use and will need to be yielded from
2583 // the newly created selection / loop region.
2584 for (Block *block : constructBlocks) {
2585 for (Operation &op : *block) {
2586 if (!op.use_empty())
2587 for (Value result : op.getResults()) {
2588 valuesToYield.push_back(mapper.lookupOrNull(result));
2589 outsideUses.push_back(result);
2590 }
2591 }
2592 for (BlockArgument &arg : block->getArguments()) {
2593 if (!arg.use_empty()) {
2594 valuesToYield.push_back(mapper.lookupOrNull(arg));
2595 outsideUses.push_back(arg);
2596 }
2597 }
2598 }
2599
2600 assert(valuesToYield.size() == outsideUses.size());
2601
2602 // If we need to yield any values from the selection / loop region we will
2603 // take care of it here.
2604 if (!valuesToYield.empty()) {
2605 LLVM_DEBUG(logger.startLine()
2606 << "[cf] yielding values from the selection / loop region\n");
2607
2608 // Update `mlir.merge` with values to be yield.
2609 auto mergeOps = body.back().getOps<spirv::MergeOp>();
2610 Operation *merge = llvm::getSingleElement(mergeOps);
2611 assert(merge);
2612 merge->setOperands(valuesToYield);
2613
2614 // MLIR does not allow changing the number of results of an operation, so
2615 // we create a new SelectionOp / LoopOp with required list of results and
2616 // move the region from the initial SelectionOp / LoopOp. The initial
2617 // operation is then removed. Since we move the region to the new op all
2618 // links between blocks and remapping we have previously done should be
2619 // preserved.
2620 builder.setInsertionPoint(&mergeBlock->front());
2621
2622 Operation *newOp = nullptr;
2623
2624 if (isLoop)
2625 newOp = spirv::LoopOp::create(builder, location,
2626 TypeRange(ValueRange(outsideUses)),
2627 static_cast<spirv::LoopControl>(control));
2628 else
2629 newOp = spirv::SelectionOp::create(
2630 builder, location, TypeRange(ValueRange(outsideUses)),
2631 static_cast<spirv::SelectionControl>(control));
2632
2633 newOp->getRegion(0).takeBody(body);
2634
2635 // Remove initial op and swap the pointer to the newly created one.
2636 op->erase();
2637 op = newOp;
2638
2639 // Update all outside uses to use results of the SelectionOp / LoopOp and
2640 // remove block arguments from the original merge block.
2641 for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
2642 outsideUses[i].replaceAllUsesWith(op->getResult(i));
2643
2644 // We do not support block arguments in loop merge block. Also running this
2645 // function with loop would break some of the loop specific code above
2646 // dealing with block arguments.
2647 if (!isLoop)
2648 mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
2649 }
2650
2651 // Check that whether some op in the to-be-erased blocks still has uses. Those
2652 // uses come from blocks that won't be sinked into the SelectionOp/LoopOp's
2653 // region. We cannot handle such cases given that once a value is sinked into
2654 // the SelectionOp/LoopOp's region, there is no escape for it.
2655 for (auto *block : constructBlocks) {
2656 if (!block->use_empty())
2657 return emitError(block->getParent()->getLoc(),
2658 "failed control flow structurization: "
2659 "block has uses outside of the "
2660 "enclosing selection/loop construct");
2661 for (Operation &op : *block)
2662 if (!op.use_empty())
2663 return op.emitOpError("failed control flow structurization: value has "
2664 "uses outside of the "
2665 "enclosing selection/loop construct");
2666 for (BlockArgument &arg : block->getArguments())
2667 if (!arg.use_empty())
2668 return emitError(arg.getLoc(), "failed control flow structurization: "
2669 "block argument has uses outside of the "
2670 "enclosing selection/loop construct");
2671 }
2672
2673 // Then erase all old blocks.
2674 for (auto *block : constructBlocks) {
2675 // We've cloned all blocks belonging to this construct into the structured
2676 // control flow op's region. Among these blocks, some may compose another
2677 // selection/loop. If so, they will be recorded within blockMergeInfo.
2678 // We need to update the pointers there to the newly remapped ones so we can
2679 // continue structurizing them later.
2680 //
2681 // We need to walk each block as constructBlocks do not include blocks
2682 // internal to ops already structured within those blocks. It is not
2683 // fully clear to me why the mergeInfo of blocks (yet to be structured)
2684 // inside already structured selections/loops get invalidated and needs
2685 // updating, however the following example code can cause a crash (depending
2686 // on the structuring order), when the most inner selection is being
2687 // structured after the outer selection and loop have been already
2688 // structured:
2689 //
2690 // spirv.mlir.for {
2691 // // ...
2692 // spirv.mlir.selection {
2693 // // ..
2694 // // A selection region that hasn't been yet structured!
2695 // // ..
2696 // }
2697 // // ...
2698 // }
2699 //
2700 // If the loop gets structured after the outer selection, but before the
2701 // inner selection. Moving the already structured selection inside the loop
2702 // will invalidate the mergeInfo of the region that is not yet structured.
2703 // Just going over constructBlocks will not check and updated header blocks
2704 // inside the already structured selection region. Walking block fixes that.
2705 //
2706 // TODO: If structuring was done in a fixed order starting with inner
2707 // most constructs this most likely not be an issue and the whole code
2708 // section could be removed. However, with the current non-deterministic
2709 // order this is not possible.
2710 //
2711 // TODO: The asserts in the following assumes input SPIR-V blob forms
2712 // correctly nested selection/loop constructs. We should relax this and
2713 // support error cases better.
2714 auto updateMergeInfo = [&](Block *block) -> WalkResult {
2715 auto it = blockMergeInfo.find(block);
2716 if (it != blockMergeInfo.end()) {
2717 // Use the original location for nested selection/loop ops.
2718 Location loc = it->second.loc;
2719
2720 Block *newHeader = mapper.lookupOrNull(block);
2721 if (!newHeader)
2722 return emitError(loc, "failed control flow structurization: nested "
2723 "loop header block should be remapped!");
2724
2725 Block *newContinue = it->second.continueBlock;
2726 if (newContinue) {
2727 newContinue = mapper.lookupOrNull(newContinue);
2728 if (!newContinue)
2729 return emitError(loc, "failed control flow structurization: nested "
2730 "loop continue block should be remapped!");
2731 }
2732
2733 Block *newMerge = it->second.mergeBlock;
2734 if (Block *mappedTo = mapper.lookupOrNull(newMerge))
2735 newMerge = mappedTo;
2736
2737 // The iterator should be erased before adding a new entry into
2738 // blockMergeInfo to avoid iterator invalidation.
2739 blockMergeInfo.erase(it);
2740 blockMergeInfo.try_emplace(newHeader, loc, it->second.control, newMerge,
2741 newContinue);
2742 }
2743
2744 return WalkResult::advance();
2745 };
2746
2747 if (block->walk(updateMergeInfo).wasInterrupted())
2748 return failure();
2749
2750 // The structured selection/loop's entry block does not have arguments.
2751 // If the function's header block is also part of the structured control
2752 // flow, we cannot just simply erase it because it may contain arguments
2753 // matching the function signature and used by the cloned blocks.
2754 if (isFnEntryBlock(block)) {
2755 LLVM_DEBUG(logger.startLine() << "[cf] changing entry block " << block
2756 << " to only contain a spirv.Branch op\n");
2757 // Still keep the function entry block for the potential block arguments,
2758 // but replace all ops inside with a branch to the merge block.
2759 block->clear();
2760 builder.setInsertionPointToEnd(block);
2761 spirv::BranchOp::create(builder, location, mergeBlock);
2762 } else {
2763 LLVM_DEBUG(logger.startLine() << "[cf] erasing block " << block << "\n");
2764 block->erase();
2765 }
2766 }
2767
2768 LLVM_DEBUG(logger.startLine()
2769 << "[cf] after structurizing construct with header block "
2770 << headerBlock << ":\n"
2771 << *op << "\n");
2772
2773 return success();
2774}
2775
2777 LLVM_DEBUG({
2778 logger.startLine()
2779 << "//----- [phi] start wiring up block arguments -----//\n";
2780 logger.indent();
2781 });
2782
2783 OpBuilder::InsertionGuard guard(opBuilder);
2784
2785 for (const auto &info : blockPhiInfo) {
2786 Block *block = info.first.first;
2787 Block *target = info.first.second;
2788 const BlockPhiInfo &phiInfo = info.second;
2789 LLVM_DEBUG({
2790 logger.startLine() << "[phi] block " << block << "\n";
2791 logger.startLine() << "[phi] before creating block argument:\n";
2792 block->getParentOp()->print(logger.getOStream());
2793 logger.startLine() << "\n";
2794 });
2795
2796 // Set insertion point to before this block's terminator early because we
2797 // may materialize ops via getValue() call.
2798 auto *op = block->getTerminator();
2799 opBuilder.setInsertionPoint(op);
2800
2801 SmallVector<Value, 4> blockArgs;
2802 blockArgs.reserve(phiInfo.size());
2803 for (uint32_t valueId : phiInfo) {
2804 if (Value value = getValue(valueId)) {
2805 blockArgs.push_back(value);
2806 LLVM_DEBUG(logger.startLine() << "[phi] block argument " << value
2807 << " id = " << valueId << "\n");
2808 } else {
2809 return emitError(unknownLoc, "OpPhi references undefined value!");
2810 }
2811 }
2812
2813 if (auto branchOp = dyn_cast<spirv::BranchOp>(op)) {
2814 // Replace the previous branch op with a new one with block arguments.
2815 spirv::BranchOp::create(opBuilder, branchOp.getLoc(),
2816 branchOp.getTarget(), blockArgs);
2817 branchOp.erase();
2818 } else if (auto branchCondOp = dyn_cast<spirv::BranchConditionalOp>(op)) {
2819 assert((branchCondOp.getTrueBlock() == target ||
2820 branchCondOp.getFalseBlock() == target) &&
2821 "expected target to be either the true or false target");
2822 if (target == branchCondOp.getTrueTarget())
2823 spirv::BranchConditionalOp::create(
2824 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2825 blockArgs, branchCondOp.getFalseBlockArguments(),
2826 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueTarget(),
2827 branchCondOp.getFalseTarget());
2828 else
2829 spirv::BranchConditionalOp::create(
2830 opBuilder, branchCondOp.getLoc(), branchCondOp.getCondition(),
2831 branchCondOp.getTrueBlockArguments(), blockArgs,
2832 branchCondOp.getBranchWeightsAttr(), branchCondOp.getTrueBlock(),
2833 branchCondOp.getFalseBlock());
2834
2835 branchCondOp.erase();
2836 } else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2837 if (target == switchOp.getDefaultTarget()) {
2838 SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
2839 DenseIntElementsAttr literals =
2840 switchOp.getLiterals().value_or(DenseIntElementsAttr());
2841 spirv::SwitchOp::create(
2842 opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2843 switchOp.getDefaultTarget(), blockArgs, literals,
2844 switchOp.getTargets(), targetOperands);
2845 switchOp.erase();
2846 } else {
2847 SuccessorRange targets = switchOp.getTargets();
2848 auto it = llvm::find(targets, target);
2849 assert(it != targets.end());
2850 size_t index = std::distance(targets.begin(), it);
2851 switchOp.getTargetOperandsMutable(index).assign(blockArgs);
2852 }
2853 } else {
2854 return emitError(unknownLoc, "unimplemented terminator for Phi creation");
2855 }
2856
2857 LLVM_DEBUG({
2858 logger.startLine() << "[phi] after creating block argument:\n";
2859 block->getParentOp()->print(logger.getOStream());
2860 logger.startLine() << "\n";
2861 });
2862 }
2863 blockPhiInfo.clear();
2864
2865 LLVM_DEBUG({
2866 logger.unindent();
2867 logger.startLine()
2868 << "//--- [phi] completed wiring up block arguments ---//\n";
2869 });
2870 return success();
2871}
2872
2874 // Create a copy, so we can modify keys in the original.
2875 BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2876 for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2877 it != e; ++it) {
2878 auto &[block, mergeInfo] = *it;
2879
2880 // Skip processing loop regions. For loop regions continueBlock is non-null.
2881 if (mergeInfo.continueBlock)
2882 continue;
2883
2884 if (!block->mightHaveTerminator())
2885 continue;
2886
2887 Operation *terminator = block->getTerminator();
2888 assert(terminator);
2889
2890 if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
2891 continue;
2892
2893 // Check if the current header block is a merge block of another construct.
2894 bool splitHeaderMergeBlock = false;
2895 for (const auto &[_, mergeInfo] : blockMergeInfo) {
2896 if (mergeInfo.mergeBlock == block)
2897 splitHeaderMergeBlock = true;
2898 }
2899
2900 // Do not split a block that only contains a conditional branch / switch,
2901 // unless it is also a merge block of another construct - in that case we
2902 // want to split the block. We do not want two constructs to share header /
2903 // merge block.
2904 if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
2905 Block *newBlock = block->splitBlock(terminator);
2906 OpBuilder builder(block, block->end());
2907 spirv::BranchOp::create(builder, block->getParent()->getLoc(), newBlock);
2908
2909 // After splitting we need to update the map to use the new block as a
2910 // header.
2911 blockMergeInfo.erase(block);
2912 blockMergeInfo.try_emplace(newBlock, mergeInfo);
2913 }
2914 }
2915
2916 return success();
2917}
2918
2920 if (!options.enableControlFlowStructurization) {
2921 LLVM_DEBUG(
2922 {
2923 logger.startLine()
2924 << "//----- [cf] skip structurizing control flow -----//\n";
2925 logger.indent();
2926 });
2927 return success();
2928 }
2929
2930 LLVM_DEBUG({
2931 logger.startLine()
2932 << "//----- [cf] start structurizing control flow -----//\n";
2933 logger.indent();
2934 });
2935
2936 LLVM_DEBUG({
2937 logger.startLine() << "[cf] split conditional blocks\n";
2938 logger.startLine() << "\n";
2939 });
2940
2941 if (failed(splitSelectionHeader())) {
2942 return failure();
2943 }
2944
2945 while (!blockMergeInfo.empty()) {
2946 Block *headerBlock = blockMergeInfo.begin()->first;
2947 BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;
2948
2949 LLVM_DEBUG({
2950 logger.startLine() << "[cf] header block " << headerBlock << ":\n";
2951 headerBlock->print(logger.getOStream());
2952 logger.startLine() << "\n";
2953 });
2954
2955 auto *mergeBlock = mergeInfo.mergeBlock;
2956 assert(mergeBlock && "merge block cannot be nullptr");
2957 if (mergeInfo.continueBlock && !mergeBlock->args_empty())
2958 return emitError(unknownLoc, "OpPhi in loop merge block unimplemented");
2959 LLVM_DEBUG({
2960 logger.startLine() << "[cf] merge block " << mergeBlock << ":\n";
2961 mergeBlock->print(logger.getOStream());
2962 logger.startLine() << "\n";
2963 });
2964
2965 auto *continueBlock = mergeInfo.continueBlock;
2966 LLVM_DEBUG(if (continueBlock) {
2967 logger.startLine() << "[cf] continue block " << continueBlock << ":\n";
2968 continueBlock->print(logger.getOStream());
2969 logger.startLine() << "\n";
2970 });
2971 // Erase this case before calling into structurizer, who will update
2972 // blockMergeInfo.
2973 blockMergeInfo.erase(blockMergeInfo.begin());
2974 ControlFlowStructurizer structurizer(mergeInfo.loc, mergeInfo.control,
2975 blockMergeInfo, headerBlock,
2976 mergeBlock, continueBlock
2977#ifndef NDEBUG
2978 ,
2979 logger
2980#endif
2981 );
2982 if (failed(structurizer.structurize()))
2983 return failure();
2984 }
2985
2986 LLVM_DEBUG({
2987 logger.unindent();
2988 logger.startLine()
2989 << "//--- [cf] completed structurizing control flow ---//\n";
2990 });
2991 return success();
2992}
2993
2994//===----------------------------------------------------------------------===//
2995// Debug
2996//===----------------------------------------------------------------------===//
2997
2999 if (!debugLine)
3000 return unknownLoc;
3001
3002 auto fileName = debugInfoMap.lookup(debugLine->fileID).str();
3003 if (fileName.empty())
3004 fileName = "<unknown>";
3005 return FileLineColLoc::get(opBuilder.getStringAttr(fileName), debugLine->line,
3006 debugLine->column);
3007}
3008
3009LogicalResult
3011 // According to SPIR-V spec:
3012 // "This location information applies to the instructions physically
3013 // following this instruction, up to the first occurrence of any of the
3014 // following: the next end of block, the next OpLine instruction, or the next
3015 // OpNoLine instruction."
3016 if (operands.size() != 3)
3017 return emitError(unknownLoc, "OpLine must have 3 operands");
3018 debugLine = DebugLine{operands[0], operands[1], operands[2]};
3019 return success();
3020}
3021
3022void spirv::Deserializer::clearDebugLine() { debugLine = std::nullopt; }
3023
3024LogicalResult
3026 if (operands.size() < 2)
3027 return emitError(unknownLoc, "OpString needs at least 2 operands");
3028
3029 if (!debugInfoMap.lookup(operands[0]).empty())
3030 return emitError(unknownLoc,
3031 "duplicate debug string found for result <id> ")
3032 << operands[0];
3033
3034 unsigned wordIndex = 1;
3035 StringRef debugString = decodeStringLiteral(operands, wordIndex);
3036 if (wordIndex != operands.size())
3037 return emitError(unknownLoc,
3038 "unexpected trailing words in OpString instruction");
3039
3040 debugInfoMap[operands[0]] = debugString;
3041 return success();
3042}
return success()
static bool isLoop(Operation *op)
Returns true if the given operation represents a loop by testing whether it implements the LoopLikeOp...
static bool isFnEntryBlock(Block *block)
Returns true if the given block is a function entry block.
#define MIN_VERSION_CASE(v)
static LogicalResult deserializeCacheControlDecoration(Location loc, OpBuilder &opBuilder, DenseMap< uint32_t, NamedAttrList > &decorations, ArrayRef< uint32_t > words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind)
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class represents an argument of a Block.
Definition Value.h:309
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
void erase()
Unlink this Block from its parent region and delete it.
Definition Block.cpp:66
Block * splitBlock(iterator splitBefore)
Split the block into two blocks before the specified operation or iterator.
Definition Block.cpp:323
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
void print(raw_ostream &os)
bool args_empty()
Definition Block.h:109
iterator begin()
Definition Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition Block.cpp:36
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static FileLineColLoc get(StringAttr filename, unsigned line, unsigned column)
Definition Location.cpp:157
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
Definition IRMapping.h:30
auto lookupOrNull(T from) const
Lookup a mapped value within the map.
Definition IRMapping.h:58
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
MutableArrayRef< BlockOperand > getBlockOperands()
Definition Operation.h:695
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool use_empty()
Returns true if this operation has no uses.
Definition Operation.h:852
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
void print(raw_ostream &os, const OpPrintingFlags &flags={})
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
void erase()
Remove this operation from its parent block and delete it.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & back()
Definition Region.h:64
iterator end()
Definition Region.h:56
BlockListType & getBlocks()
Definition Region.h:45
BlockListType::iterator iterator
Definition Region.h:52
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
Definition Region.h:241
This class implements the successor iterators for Block.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
static WalkResult advance()
Definition WalkResult.h:47
static ArrayType get(Type elementType, unsigned elementCount)
static CooperativeMatrixType get(Type elementType, uint32_t rows, uint32_t columns, Scope scope, CooperativeMatrixUseKHR use)
LogicalResult wireUpBlockArgument()
Creates block arguments on predecessors previously recorded when handling OpPhi instructions.
Value materializeSpecConstantOperation(uint32_t resultID, spirv::Opcode enclosedOpcode, uint32_t resultTypeID, ArrayRef< uint32_t > enclosedOpOperands)
Materializes/emits an OpSpecConstantOp instruction.
LogicalResult processOpTypePointer(ArrayRef< uint32_t > operands)
Value getValue(uint32_t id)
Get the Value associated with a result <id>.
LogicalResult processMatrixType(ArrayRef< uint32_t > operands)
LogicalResult processGlobalVariable(ArrayRef< uint32_t > operands)
Processes the OpVariable instructions at current offset into binary.
std::optional< SpecConstOperationMaterializationInfo > getSpecConstantOperation(uint32_t id)
Gets the info needed to materialize the spec constant operation op associated with the given <id>.
LogicalResult processConstantNull(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantNull instruction with the given operands.
LogicalResult processSpecConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantComposite instruction with the given operands.
LogicalResult processInstruction(spirv::Opcode opcode, ArrayRef< uint32_t > operands, bool deferInstructions=true)
Processes a SPIR-V instruction with the given opcode and operands.
LogicalResult processBranchConditional(ArrayRef< uint32_t > operands)
spirv::GlobalVariableOp getGlobalVariable(uint32_t id)
Gets the global variable associated with a result <id> of OpVariable.
LogicalResult createGraphBlock(uint32_t graphID)
Creates a block for graph with the given graphID.
LogicalResult processStructType(ArrayRef< uint32_t > operands)
LogicalResult processGraphARM(ArrayRef< uint32_t > operands)
LogicalResult setFunctionArgAttrs(uint32_t argID, SmallVectorImpl< Attribute > &argAttrs, size_t argIndex)
Sets the function argument's attributes.
LogicalResult structurizeControlFlow()
Extracts blocks belonging to a structured selection/loop into a spirv.mlir.selection/spirv....
LogicalResult processLabel(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLabel instruction with the given operands.
LogicalResult processSampledImageType(ArrayRef< uint32_t > operands)
LogicalResult processTensorARMType(ArrayRef< uint32_t > operands)
std::optional< spirv::GraphConstantARMOpMaterializationInfo > getGraphConstantARM(uint32_t id)
Gets the GraphConstantARM ID attribute and result type with the given result <id>.
std::optional< std::pair< Attribute, Type > > getConstant(uint32_t id)
Gets the constant's attribute and type associated with the given <id>.
LogicalResult processType(spirv::Opcode opcode, ArrayRef< uint32_t > operands)
Processes a SPIR-V type instruction with given opcode and operands and registers the type into module...
LogicalResult processLoopMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLoopMerge instruction with the given operands.
LogicalResult processArrayType(ArrayRef< uint32_t > operands)
LogicalResult sliceInstruction(spirv::Opcode &opcode, ArrayRef< uint32_t > &operands, std::optional< spirv::Opcode > expectedOpcode=std::nullopt)
Slices the first instruction out of binary and returns its opcode and operands via opcode and operand...
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id)
Gets the composite specialization constant with the given result <id>.
SmallVector< uint32_t, 2 > BlockPhiInfo
For OpPhi instructions, we use block arguments to represent them.
LogicalResult processSpecConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processCooperativeMatrixTypeKHR(ArrayRef< uint32_t > operands)
LogicalResult processGraphEntryPointARM(ArrayRef< uint32_t > operands)
LogicalResult processFunction(ArrayRef< uint32_t > operands)
Creates a deserializer for the given SPIR-V binary module.
StringAttr getSymbolDecoration(StringRef decorationName)
Gets the symbol name from the name of decoration.
Block * getOrCreateBlock(uint32_t id)
Gets or creates the block corresponding to the given label <id>.
bool isVoidType(Type type) const
Returns true if the given type is for SPIR-V void type.
std::string getSpecConstantSymbol(uint32_t id)
Returns a symbol to be used for the specialization constant with the given result <id>.
LogicalResult processDebugString(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpString instruction with the given operands.
LogicalResult processPhi(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpPhi instruction with the given operands.
std::string getFunctionSymbol(uint32_t id)
Returns a symbol to be used for the function name with the given result <id>.
void clearDebugLine()
Discontinues any source-level location information that might be active from a previous OpLine instru...
LogicalResult processFunctionType(ArrayRef< uint32_t > operands)
IntegerAttr getConstantInt(uint32_t id)
Gets the constant's integer attribute with the given <id>.
LogicalResult processTypeForwardPointer(ArrayRef< uint32_t > operands)
LogicalResult processSwitch(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSwitch instruction with the given operands.
LogicalResult processGraphEndARM(ArrayRef< uint32_t > operands)
LogicalResult processImageType(ArrayRef< uint32_t > operands)
LogicalResult processConstantComposite(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantComposite instruction with the given operands.
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID, TypedAttr defaultValue)
Creates a spirv::SpecConstantOp.
Block * getBlock(uint32_t id) const
Returns the block for the given label <id>.
LogicalResult processGraphTypeARM(ArrayRef< uint32_t > operands)
LogicalResult processBranch(ArrayRef< uint32_t > operands)
std::optional< std::pair< Attribute, Type > > getConstantCompositeReplicate(uint32_t id)
Gets the replicated composite constant's attribute and type associated with the given <id>.
LogicalResult processFunctionEnd(ArrayRef< uint32_t > operands)
Processes OpFunctionEnd and finalizes function.
LogicalResult processRuntimeArrayType(ArrayRef< uint32_t > operands)
LogicalResult processSpecConstantOperation(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSpecConstantOp instruction with the given operands.
LogicalResult processConstant(ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant instruction with the given operands.
Location createFileLineColLoc(OpBuilder opBuilder)
Creates a FileLineColLoc with the OpLine location information.
LogicalResult processGraphConstantARM(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpGraphConstantARM instruction with the given operands.
LogicalResult processConstantBool(bool isTrue, ArrayRef< uint32_t > operands, bool isSpec)
Processes a SPIR-V Op{|Spec}Constant{True|False} instruction with the given operands.
spirv::SpecConstantOp getSpecConstant(uint32_t id)
Gets the specialization constant with the given result <id>.
LogicalResult processConstantCompositeReplicateEXT(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpConstantCompositeReplicateEXT instruction with the given operands.
LogicalResult processSelectionMerge(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpSelectionMerge instruction with the given operands.
LogicalResult processOpGraphSetOutputARM(ArrayRef< uint32_t > operands)
LogicalResult processDebugLine(ArrayRef< uint32_t > operands)
Processes a SPIR-V OpLine instruction with the given operands.
LogicalResult splitSelectionHeader()
Move a conditional branch or a switch into a separate basic block to avoid unnecessary sinking of def...
std::string getGraphSymbol(uint32_t id)
Returns a symbol to be used for the graph name with the given result <id>.
static ImageType get(Type elementType, Dim dim, ImageDepthInfo depth=ImageDepthInfo::DepthUnknown, ImageArrayedInfo arrayed=ImageArrayedInfo::NonArrayed, ImageSamplingInfo samplingInfo=ImageSamplingInfo::SingleSampled, ImageSamplerUseInfo samplerUse=ImageSamplerUseInfo::SamplerUnknown, ImageFormat format=ImageFormat::Unknown)
Definition SPIRVTypes.h:147
static MatrixType get(Type columnType, uint32_t columnCount)
static PointerType get(Type pointeeType, StorageClass storageClass)
static RuntimeArrayType get(Type elementType)
static SampledImageType get(Type imageType)
static StructType getIdentified(MLIRContext *context, StringRef identifier)
Construct an identified StructType.
static StructType getEmpty(MLIRContext *context, StringRef identifier="")
Construct a (possibly identified) StructType with no members.
static StructType get(ArrayRef< Type > memberTypes, ArrayRef< OffsetInfo > offsetInfo={}, ArrayRef< MemberDecorationInfo > memberDecorations={}, ArrayRef< StructDecorationInfo > structDecorations={})
Construct a literal StructType with at least one member.
static TensorArmType get(ArrayRef< int64_t > shape, Type elementType)
static VerCapExtAttr get(Version version, ArrayRef< Capability > capabilities, ArrayRef< Extension > extensions, MLIRContext *context)
Gets a VerCapExtAttr instance.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
constexpr uint32_t kMagicNumber
SPIR-V magic number.
llvm::MapVector< Block *, BlockMergeInfo > BlockMergeInfoMap
Map from a selection/loop's header block to its merge (and continue) target.
StringRef decodeStringLiteral(ArrayRef< uint32_t > words, unsigned &wordIndex)
Decodes a string literal in words starting at wordIndex.
constexpr unsigned kHeaderWordCount
SPIR-V binary header word count.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static std::string debugString(T &&op)
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
A struct for containing a header block's merge and continue targets.
A struct for containing OpLine instruction information.
A struct that collects the info needed to materialize/emit a GraphConstantARMOp.
A struct that collects the info needed to materialize/emit a SpecConstantOperation op.