MLIR 22.0.0git
SerializeOps.cpp
Go to the documentation of this file.
1//===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the serialization methods for MLIR SPIR-V module ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Serializer.h"
14
19#include "llvm/ADT/DepthFirstIterator.h"
20#include "llvm/ADT/StringExtras.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "spirv-serialization"
24
25using namespace mlir;
26
27/// A pre-order depth-first visitor function for processing basic blocks.
28///
29/// Visits the basic blocks starting from the given `headerBlock` in pre-order
30/// depth-first manner and calls `blockHandler` on each block. Skips handling
31/// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32/// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33/// successors.
34///
35/// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36/// of blocks in a function must satisfy the rule that blocks appear before
37/// all blocks they dominate." This can be achieved by a pre-order CFG
38/// traversal algorithm. To make the serialization output more logical and
39/// readable to human, we perform depth-first CFG traversal and delay the
40/// serialization of the merge block and the continue block, if exists, until
41/// after all other blocks have been processed.
42static LogicalResult
44 function_ref<LogicalResult(Block *)> blockHandler,
45 bool skipHeader = false, BlockRange skipBlocks = {}) {
46 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
48
49 for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50 if (skipHeader && block == headerBlock)
51 continue;
52 if (failed(blockHandler(block)))
53 return failure();
54 }
55 return success();
56}
57
58namespace mlir {
59namespace spirv {
60LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61 if (auto resultID =
62 prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63 valueIDMap[op.getResult()] = resultID;
64 return success();
65 }
66 return failure();
67}
68
69LogicalResult Serializer::processConstantCompositeReplicateOp(
70 spirv::EXTConstantCompositeReplicateOp op) {
71 if (uint32_t resultID = prepareConstantCompositeReplicate(
72 op.getLoc(), op.getType(), op.getValue())) {
73 valueIDMap[op.getResult()] = resultID;
74 return success();
75 }
76 return failure();
77}
78
79LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
80 if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
81 /*isSpec=*/true)) {
82 // Emit the OpDecorate instruction for SpecId.
83 if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
84 auto val = static_cast<uint32_t>(specID.getInt());
85 if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
86 return failure();
87 }
88
89 specConstIDMap[op.getSymName()] = resultID;
90 return processName(resultID, op.getSymName());
91 }
92 return failure();
93}
94
95LogicalResult
96Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
97 uint32_t typeID = 0;
98 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
99 return failure();
100 }
101
102 auto resultID = getNextID();
103
105 operands.push_back(typeID);
106 operands.push_back(resultID);
107
108 auto constituents = op.getConstituents();
109
110 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
111 auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
112
113 auto constituentName = constituent.getValue();
114 auto constituentID = getSpecConstID(constituentName);
115
116 if (!constituentID) {
117 return op.emitError("unknown result <id> for specialization constant ")
118 << constituentName;
119 }
120
121 operands.push_back(constituentID);
122 }
123
124 encodeInstructionInto(typesGlobalValues,
125 spirv::Opcode::OpSpecConstantComposite, operands);
126 specConstIDMap[op.getSymName()] = resultID;
127
128 return processName(resultID, op.getSymName());
129}
130
131LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
132 spirv::EXTSpecConstantCompositeReplicateOp op) {
133 uint32_t typeID = 0;
134 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
135 return failure();
136 }
137
138 auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
139 if (!constituent)
140 return op.emitError(
141 "expected flat symbol reference for constituent instead of ")
142 << op.getConstituent();
143
144 StringRef constituentName = constituent.getValue();
145 uint32_t constituentID = getSpecConstID(constituentName);
146 if (!constituentID) {
147 return op.emitError("unknown result <id> for replicated spec constant ")
148 << constituentName;
149 }
150
151 uint32_t resultID = getNextID();
152 uint32_t operands[] = {typeID, resultID, constituentID};
153
154 encodeInstructionInto(typesGlobalValues,
155 spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
156 operands);
157
158 specConstIDMap[op.getSymName()] = resultID;
159
160 return processName(resultID, op.getSymName());
161}
162
163LogicalResult
164Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
165 uint32_t typeID = 0;
166 if (failed(processType(op.getLoc(), op.getType(), typeID))) {
167 return failure();
169
170 auto resultID = getNextID();
171
173 operands.push_back(typeID);
174 operands.push_back(resultID);
175
176 Block &block = op.getRegion().getBlocks().front();
177 Operation &enclosedOp = block.getOperations().front();
178
179 std::string enclosedOpName;
180 llvm::raw_string_ostream rss(enclosedOpName);
181 rss << "Op" << enclosedOp.getName().stripDialect();
182 auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
183
184 if (!enclosedOpcode) {
185 op.emitError("Couldn't find op code for op ")
186 << enclosedOp.getName().getStringRef();
187 return failure();
188 }
189
190 operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
191
192 // Append operands to the enclosed op to the list of operands.
193 for (Value operand : enclosedOp.getOperands()) {
194 uint32_t id = getValueID(operand);
195 assert(id && "use before def!");
196 operands.push_back(id);
197 }
198
199 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
200 operands);
201 valueIDMap[op.getResult()] = resultID;
202
203 return success();
204}
205
206LogicalResult
207Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
208 if (uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(),
209 op.getGraphConstantIdAttr())) {
210 valueIDMap[op.getResult()] = resultID;
211 return success();
212 }
213 return failure();
214}
215
216LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
217 auto undefType = op.getType();
218 auto &id = undefValIDMap[undefType];
219 if (!id) {
220 id = getNextID();
221 uint32_t typeID = 0;
222 if (failed(processType(op.getLoc(), undefType, typeID)))
223 return failure();
224 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
225 {typeID, id});
226 }
227 valueIDMap[op.getResult()] = id;
228 return success();
229}
230
231LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
232 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
233 uint32_t argTypeID = 0;
234 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
235 return failure();
236 }
237 auto argValueID = getNextID();
239 // Process decoration attributes of arguments.
240 auto funcOp = cast<FunctionOpInterface>(*op);
241 for (auto argAttr : funcOp.getArgAttrs(idx)) {
242 if (argAttr.getName() != DecorationAttr::name)
243 continue;
244
245 if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
246 if (failed(processDecorationAttr(op->getLoc(), argValueID,
247 decAttr.getValue(), decAttr)))
248 return failure();
249 }
250 }
251
252 valueIDMap[arg] = argValueID;
253 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
254 {argTypeID, argValueID});
255 }
256 return success();
257}
258
259LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
260 LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
261 assert(functionHeader.empty() && functionBody.empty());
262
263 uint32_t fnTypeID = 0;
264 // Generate type of the function.
265 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
266 return failure();
267
268 // Add the function definition.
269 SmallVector<uint32_t, 4> operands;
270 uint32_t resTypeID = 0;
271 auto resultTypes = op.getFunctionType().getResults();
272 if (resultTypes.size() > 1) {
273 return op.emitError("cannot serialize function with multiple return types");
274 }
275 if (failed(processType(op.getLoc(),
276 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
277 resTypeID))) {
278 return failure();
279 }
280 operands.push_back(resTypeID);
281 auto funcID = getOrCreateFunctionID(op.getName());
282 operands.push_back(funcID);
283 operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
284 operands.push_back(fnTypeID);
285 encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
286
287 // Add function name.
288 if (failed(processName(funcID, op.getName()))) {
289 return failure();
290 }
291 // Handle external functions with linkage_attributes(LinkageAttributes)
292 // differently.
293 auto linkageAttr = op.getLinkageAttributes();
294 auto hasImportLinkage =
295 linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
296 spirv::LinkageType::Import);
297 if (op.isExternal() && !hasImportLinkage) {
298 return op.emitError(
299 "'spirv.module' cannot contain external functions "
300 "without 'Import' linkage_attributes (LinkageAttributes)");
301 }
302 if (op.isExternal() && hasImportLinkage) {
303 // Add an entry block to set up the block arguments
304 // to match the signature of the function.
305 // This is to generate OpFunctionParameter for functions with
306 // LinkageAttributes.
307 // WARNING: This operation has side-effect, it essentially adds a body
308 // to the func. Hence, making it not external anymore (isExternal()
309 // is going to return false for this function from now on)
310 // Hence, we'll remove the body once we are done with the serialization.
311 op.addEntryBlock();
312 if (failed(processFuncParameter(op)))
313 return failure();
314
315 // Erasing the body of the function destroys arguments, so we need to remove
316 // them from the map to avoid problems when processing invalid values used
317 // as keys. We have already serialized function arguments so we probably can
318 // remove them from the map as external function will not have any uses.
319 for (Value arg : op.getArguments())
320 valueIDMap.erase(arg);
321
322 // Don't need to process the added block, there is nothing to process,
323 // the fake body was added just to get the arguments, remove the body,
324 // since it's use is done.
325 op.eraseBody();
326 } else {
327 if (failed(processFuncParameter(op)))
328 return failure();
329
330 // Some instructions (e.g., OpVariable) in a function must be in the first
331 // block in the function. These instructions will be put in
332 // functionHeader. Thus, we put the label in functionHeader first, and
333 // omit it from the first block. OpLabel only needs to be added for
334 // functions with body (including empty body). Since, we added a fake body
335 // for functions with 'Import' Linkage attributes, these functions are
336 // essentially function delcaration, so they should not have OpLabel and a
337 // terminating instruction. That's why we skipped it for those functions.
338 encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
339 {getOrCreateBlockID(&op.front())});
340 if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
341 return failure();
343 &op.front(), [&](Block *block) { return processBlock(block); },
344 /*skipHeader=*/true))) {
345 return failure();
346 }
347
348 // There might be OpPhi instructions who have value references needing to
349 // fix.
350 for (const auto &deferredValue : deferredPhiValues) {
351 Value value = deferredValue.first;
352 uint32_t id = getValueID(value);
353 LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
354 << " to id = " << id << '\n');
355 assert(id && "OpPhi references undefined value!");
356 for (size_t offset : deferredValue.second)
357 functionBody[offset] = id;
358 }
359 deferredPhiValues.clear();
360 }
361 LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
362 << "' --\n");
363 // Insert Decorations based on Function Attributes.
364 // Only attributes we should be considering for decoration are the
365 // ::mlir::spirv::Decoration attributes.
366
367 for (auto attr : op->getAttrs()) {
368 // Only generate OpDecorate op for spirv::Decoration attributes.
369 auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
370 llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
371 /*capitalizeFirst=*/true));
372 if (isValidDecoration != std::nullopt) {
373 if (failed(processDecoration(op.getLoc(), funcID, attr))) {
374 return failure();
375 }
376 }
377 }
378 // Insert OpFunctionEnd.
379 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
380
381 functions.append(functionHeader.begin(), functionHeader.end());
382 functions.append(functionBody.begin(), functionBody.end());
383 functionHeader.clear();
384 functionBody.clear();
385
386 return success();
387}
388
389LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
390 if (op.getNumResults() < 1) {
391 return op.emitError("cannot serialize graph with no return types");
392 }
393
394 LLVM_DEBUG(llvm::dbgs() << "-- start graph '" << op.getName() << "' --\n");
395 assert(functionHeader.empty() && functionBody.empty());
396
397 uint32_t funcID = getOrCreateFunctionID(op.getName());
398 uint32_t fnTypeID = 0;
399 // Generate type of the function.
400 if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
401 return failure();
402 encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphARM,
403 {fnTypeID, funcID});
404
405 // Declare the parameters.
406 for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
407 uint32_t argTypeID = 0;
408 SmallVector<uint32_t, 3> inputOperands;
409
410 if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
411 return failure();
412 }
413
414 uint32_t argValueID = getNextID();
415 valueIDMap[arg] = argValueID;
416
417 auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
418 uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false);
419
420 inputOperands.push_back(argTypeID);
421 inputOperands.push_back(argValueID);
422 inputOperands.push_back(indexID);
423
424 encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM,
425 inputOperands);
426 }
427
428 if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
429 return failure();
430 if (failed(visitInPrettyBlockOrder(
431 &op.front(), [&](Block *block) { return processBlock(block); },
432 /*skipHeader=*/true))) {
433 return failure();
434 }
435
436 LLVM_DEBUG(llvm::dbgs() << "-- completed graph '" << op.getName()
437 << "' --\n");
438 // Insert OpGraphEndARM.
439 encodeInstructionInto(functionBody, spirv::Opcode::OpGraphEndARM, {});
440
441 llvm::append_range(graphs, functionHeader);
442 llvm::append_range(graphs, functionBody);
443 functionHeader.clear();
444 functionBody.clear();
445
446 return success();
447}
448
449LogicalResult
450Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) {
451 SmallVector<uint32_t, 4> operands;
452 StringRef graph = op.getFn();
453 // Add the graph <id>.
454 uint32_t graphID = getOrCreateFunctionID(graph);
455 operands.push_back(graphID);
456 // Add the name of the graph.
457 spirv::encodeStringLiteralInto(operands, graph);
458
459 // Add the interface values.
460 if (ArrayAttr interface = op.getInterface()) {
461 for (Attribute var : interface.getValue()) {
462 StringRef value = cast<FlatSymbolRefAttr>(var).getValue();
463 if (uint32_t id = getVariableID(value)) {
464 operands.push_back(id);
465 } else {
466 return op.emitError(
467 "referencing undefined global variable."
468 "spirv.GraphEntryPointARM is at the end of spirv.module. All "
469 "referenced variables should already be defined");
470 }
471 }
472 }
473 encodeInstructionInto(graphs, spirv::Opcode::OpGraphEntryPointARM, operands);
474 return success();
475}
476
477LogicalResult
478Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
479 for (auto [idx, value] : llvm::enumerate(op->getOperands())) {
480 SmallVector<uint32_t, 2> outputOperands;
481
482 Type resType = value.getType();
483 uint32_t resTypeID = 0;
484 if (failed(processType(op.getLoc(), resType, resTypeID))) {
485 return failure();
486 }
487
488 uint32_t outputID = getValueID(value);
489 auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
490 uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false);
491
492 outputOperands.push_back(outputID);
493 outputOperands.push_back(indexID);
494
495 encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM,
496 outputOperands);
497 }
498 return success();
499}
500
501LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
503 SmallVector<StringRef, 2> elidedAttrs;
504 uint32_t resultID = 0;
505 uint32_t resultTypeID = 0;
506 if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
507 return failure();
508 }
509 operands.push_back(resultTypeID);
510 resultID = getNextID();
511 valueIDMap[op.getResult()] = resultID;
512 operands.push_back(resultID);
513 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
514 if (attr) {
515 operands.push_back(
516 static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
517 }
518 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
519 for (auto arg : op.getODSOperands(0)) {
520 auto argID = getValueID(arg);
521 if (!argID) {
522 return emitError(op.getLoc(), "operand 0 has a use before def");
523 }
524 operands.push_back(argID);
525 }
526 if (failed(emitDebugLine(functionHeader, op.getLoc())))
527 return failure();
528 encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
529 for (auto attr : op->getAttrs()) {
530 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
531 return attr.getName() == elided;
532 })) {
533 continue;
535 if (failed(processDecoration(op.getLoc(), resultID, attr))) {
536 return failure();
537 }
538 }
539 return success();
540}
541
542LogicalResult
543Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
544 // Get TypeID.
545 uint32_t resultTypeID = 0;
546 SmallVector<StringRef, 4> elidedAttrs;
547 if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
548 return failure();
549 }
550
551 elidedAttrs.push_back("type");
552 SmallVector<uint32_t, 4> operands;
553 operands.push_back(resultTypeID);
554 auto resultID = getNextID();
555
556 // Encode the name.
557 auto varName = varOp.getSymName();
558 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
559 if (failed(processName(resultID, varName))) {
560 return failure();
561 }
562 globalVarIDMap[varName] = resultID;
563 operands.push_back(resultID);
564
565 // Encode StorageClass.
566 operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
567
568 // Encode initialization.
569 StringRef initAttrName = varOp.getInitializerAttrName().getValue();
570 if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
571 uint32_t initializerID = 0;
572 auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
573 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
574 varOp->getParentOp(), initRef.getAttr());
575
576 // Check if initializer is GlobalVariable or SpecConstant* cases.
577 if (isa<spirv::GlobalVariableOp>(initOp))
578 initializerID = getVariableID(*initSymbolName);
579 else
580 initializerID = getSpecConstID(*initSymbolName);
581
582 if (!initializerID)
583 return emitError(varOp.getLoc(),
584 "invalid usage of undefined variable as initializer");
585
586 operands.push_back(initializerID);
587 elidedAttrs.push_back(initAttrName);
588 }
590 if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
591 return failure();
592 encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
593 elidedAttrs.push_back(initAttrName);
594
595 // Encode decorations.
596 for (auto attr : varOp->getAttrs()) {
597 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
598 return attr.getName() == elided;
599 })) {
600 continue;
601 }
602 if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
603 return failure();
604 }
605 }
606 return success();
607}
608
609LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
610 // Assign <id>s to all blocks so that branches inside the SelectionOp can
611 // resolve properly.
612 auto &body = selectionOp.getBody();
613 for (Block &block : body)
614 getOrCreateBlockID(&block);
615
616 auto *headerBlock = selectionOp.getHeaderBlock();
617 auto *mergeBlock = selectionOp.getMergeBlock();
618 auto headerID = getBlockID(headerBlock);
619 auto mergeID = getBlockID(mergeBlock);
620 auto loc = selectionOp.getLoc();
621
622 // Before we do anything replace results of the selection operation with
623 // values yielded (with `mlir.merge`) from inside the region. The selection op
624 // is being flattened so we do not have to worry about values being defined
625 // inside a region and used outside it anymore.
626 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
627 assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
628 for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
629 selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
630
631 // This SelectionOp is in some MLIR block with preceding and following ops. In
632 // the binary format, it should reside in separate SPIR-V blocks from its
633 // preceding and following ops. So we need to emit unconditional branches to
634 // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
635 // flow afterwards.
636 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
637
638 // Emit the selection header block, which dominates all other blocks, first.
639 // We need to emit an OpSelectionMerge instruction before the selection header
640 // block's terminator.
641 auto emitSelectionMerge = [&]() {
642 if (failed(emitDebugLine(functionBody, loc)))
643 return failure();
644 lastProcessedWasMergeInst = true;
646 functionBody, spirv::Opcode::OpSelectionMerge,
647 {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
648 return success();
649 };
650 if (failed(
651 processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
652 return failure();
653
654 // Process all blocks with a depth-first visitor starting from the header
655 // block. The selection header block and merge block are skipped by this
656 // visitor.
657 if (failed(visitInPrettyBlockOrder(
658 headerBlock, [&](Block *block) { return processBlock(block); },
659 /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
660 return failure();
661
662 // There is nothing to do for the merge block in the selection, which just
663 // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
664 // instruction to start a new SPIR-V block for ops following this SelectionOp.
665 // The block should use the <id> for the merge block.
666 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
667
668 // We do not process the mergeBlock but we still need to generate phi
669 // functions from its block arguments.
670 if (failed(emitPhiForBlockArguments(mergeBlock)))
671 return failure();
672
673 LLVM_DEBUG(llvm::dbgs() << "done merge ");
674 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
675 LLVM_DEBUG(llvm::dbgs() << "\n");
676 return success();
677}
678
679LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
680 // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
681 // properly. We don't need to assign for the entry block, which is just for
682 // satisfying MLIR region's structural requirement.
683 auto &body = loopOp.getBody();
684 for (Block &block : llvm::drop_begin(body))
685 getOrCreateBlockID(&block);
686
687 auto *headerBlock = loopOp.getHeaderBlock();
688 auto *continueBlock = loopOp.getContinueBlock();
689 auto *mergeBlock = loopOp.getMergeBlock();
690 auto headerID = getBlockID(headerBlock);
691 auto continueID = getBlockID(continueBlock);
692 auto mergeID = getBlockID(mergeBlock);
693 auto loc = loopOp.getLoc();
694
695 // Before we do anything replace results of the selection operation with
696 // values yielded (with `mlir.merge`) from inside the region.
697 auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
698 assert(loopOp.getNumResults() == mergeOp.getNumOperands());
699 for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
700 loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
701
702 // This LoopOp is in some MLIR block with preceding and following ops. In the
703 // binary format, it should reside in separate SPIR-V blocks from its
704 // preceding and following ops. So we need to emit unconditional branches to
705 // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
706 // afterwards.
707 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
708
709 // LoopOp's entry block is just there for satisfying MLIR's structural
710 // requirements so we omit it and start serialization from the loop header
711 // block.
712
713 // Emit the loop header block, which dominates all other blocks, first. We
714 // need to emit an OpLoopMerge instruction before the loop header block's
715 // terminator.
716 auto emitLoopMerge = [&]() {
717 if (failed(emitDebugLine(functionBody, loc)))
718 return failure();
719 lastProcessedWasMergeInst = true;
721 functionBody, spirv::Opcode::OpLoopMerge,
722 {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
723 return success();
724 };
725 if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
726 return failure();
727
728 // Process all blocks with a depth-first visitor starting from the header
729 // block. The loop header block, loop continue block, and loop merge block are
730 // skipped by this visitor and handled later in this function.
732 headerBlock, [&](Block *block) { return processBlock(block); },
733 /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
734 return failure();
735
736 // We have handled all other blocks. Now get to the loop continue block.
737 if (failed(processBlock(continueBlock)))
738 return failure();
739
740 // There is nothing to do for the merge block in the loop, which just contains
741 // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
742 // to start a new SPIR-V block for ops following this LoopOp. The block should
743 // use the <id> for the merge block.
744 encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
745 LLVM_DEBUG(llvm::dbgs() << "done merge ");
746 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
747 LLVM_DEBUG(llvm::dbgs() << "\n");
748 return success();
749}
750
751LogicalResult Serializer::processBranchConditionalOp(
752 spirv::BranchConditionalOp condBranchOp) {
753 auto conditionID = getValueID(condBranchOp.getCondition());
754 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
755 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
756 SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
757
758 if (auto weights = condBranchOp.getBranchWeights()) {
759 for (auto val : weights->getValue())
760 arguments.push_back(cast<IntegerAttr>(val).getInt());
761 }
762
763 if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
764 return failure();
765 encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
766 arguments);
767 return success();
768}
769
770LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
771 if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
772 return failure();
773 encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
774 {getOrCreateBlockID(branchOp.getTarget())});
775 return success();
776}
777
778LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
779 auto varName = addressOfOp.getVariable();
780 auto variableID = getVariableID(varName);
781 if (!variableID) {
782 return addressOfOp.emitError("unknown result <id> for variable ")
783 << varName;
784 }
785 valueIDMap[addressOfOp.getPointer()] = variableID;
786 return success();
787}
788
789LogicalResult
790Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
791 auto constName = referenceOfOp.getSpecConst();
792 auto constID = getSpecConstID(constName);
793 if (!constID) {
794 return referenceOfOp.emitError(
795 "unknown result <id> for specialization constant ")
796 << constName;
797 }
798 valueIDMap[referenceOfOp.getReference()] = constID;
799 return success();
800}
801
802template <>
803LogicalResult
804Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
806 // Add the ExecutionModel.
807 operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
808 // Add the function <id>.
809 auto funcID = getFunctionID(op.getFn());
810 if (!funcID) {
811 return op.emitError("missing <id> for function ")
812 << op.getFn()
813 << "; function needs to be defined before spirv.EntryPoint is "
814 "serialized";
815 }
816 operands.push_back(funcID);
817 // Add the name of the function.
818 spirv::encodeStringLiteralInto(operands, op.getFn());
820 // Add the interface values.
821 if (auto interface = op.getInterface()) {
822 for (auto var : interface.getValue()) {
823 auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
824 if (!id) {
825 return op.emitError(
826 "referencing undefined global variable."
827 "spirv.EntryPoint is at the end of spirv.module. All "
828 "referenced variables should already be defined");
830 operands.push_back(id);
831 }
832 }
833 encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
834 return success();
835}
836
837template <>
838LogicalResult
839Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
841 // Add the function <id>.
842 auto funcID = getFunctionID(op.getFn());
843 if (!funcID) {
844 return op.emitError("missing <id> for function ")
845 << op.getFn()
846 << "; function needs to be serialized before ExecutionModeOp is "
847 "serialized";
848 }
849 operands.push_back(funcID);
850 // Add the ExecutionMode.
851 operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
852
853 // Serialize values if any.
854 auto values = op.getValues();
855 if (values) {
856 for (auto &intVal : values.getValue()) {
857 operands.push_back(static_cast<uint32_t>(
858 llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
860 }
861 encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
862 operands);
863 return success();
865
866template <>
867LogicalResult
868Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
869 auto funcName = op.getCallee();
870 uint32_t resTypeID = 0;
871
872 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
873 if (failed(processType(op.getLoc(), resultTy, resTypeID)))
874 return failure();
875
876 auto funcID = getOrCreateFunctionID(funcName);
877 auto funcCallID = getNextID();
878 SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
880 for (auto value : op.getArguments()) {
881 auto valueID = getValueID(value);
882 assert(valueID && "cannot find a value for spirv.FunctionCall");
883 operands.push_back(valueID);
885
886 if (!isa<NoneType>(resultTy))
887 valueIDMap[op.getResult(0)] = funcCallID;
888
889 encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
890 return success();
891}
892
893template <>
894LogicalResult
895Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
897 SmallVector<StringRef, 2> elidedAttrs;
898
899 for (Value operand : op->getOperands()) {
900 auto id = getValueID(operand);
901 assert(id && "use before def!");
902 operands.push_back(id);
903 }
905 StringAttr memoryAccess = op.getMemoryAccessAttrName();
906 if (auto attr = op->getAttr(memoryAccess)) {
907 operands.push_back(
908 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
910
911 elidedAttrs.push_back(memoryAccess.strref());
912
913 StringAttr alignment = op.getAlignmentAttrName();
914 if (auto attr = op->getAttr(alignment)) {
915 operands.push_back(static_cast<uint32_t>(
916 cast<IntegerAttr>(attr).getValue().getZExtValue()));
917 }
918
919 elidedAttrs.push_back(alignment.strref());
920
921 StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
922 if (auto attr = op->getAttr(sourceMemoryAccess)) {
923 operands.push_back(
924 static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
925 }
926
927 elidedAttrs.push_back(sourceMemoryAccess.strref());
928
929 StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
930 if (auto attr = op->getAttr(sourceAlignment)) {
931 operands.push_back(static_cast<uint32_t>(
932 cast<IntegerAttr>(attr).getValue().getZExtValue()));
933 }
935 elidedAttrs.push_back(sourceAlignment.strref());
936 if (failed(emitDebugLine(functionBody, op.getLoc())))
937 return failure();
938 encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
940 return success();
941}
942template <>
943LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
944 spirv::GenericCastToPtrExplicitOp op) {
946 Type resultTy;
947 Location loc = op->getLoc();
948 uint32_t resultTypeID = 0;
949 uint32_t resultID = 0;
950 resultTy = op->getResult(0).getType();
951 if (failed(processType(loc, resultTy, resultTypeID)))
952 return failure();
953 operands.push_back(resultTypeID);
955 resultID = getNextID();
956 operands.push_back(resultID);
957 valueIDMap[op->getResult(0)] = resultID;
958
959 for (Value operand : op->getOperands())
960 operands.push_back(getValueID(operand));
961 spirv::StorageClass resultStorage =
962 cast<spirv::PointerType>(resultTy).getStorageClass();
963 operands.push_back(static_cast<uint32_t>(resultStorage));
964 encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
965 operands);
966 return success();
967}
968
969// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
970// various Serializer::processOp<...>() specializations.
971#define GET_SERIALIZATION_FNS
972#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
973
974} // namespace spirv
975} // namespace mlir
return success()
ArrayAttr()
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
constexpr StringRef attributeName()
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152