MLIR  21.0.0git
SerializeOps.cpp
Go to the documentation of this file.
1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "spirv-serialization"
24 
25 using namespace mlir;
26 
27 /// A pre-order depth-first visitor function for processing basic blocks.
28 ///
29 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
30 /// depth-first manner and calls `blockHandler` on each block. Skips handling
31 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33 /// successors.
34 ///
35 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36 /// of blocks in a function must satisfy the rule that blocks appear before
37 /// all blocks they dominate." This can be achieved by a pre-order CFG
38 /// traversal algorithm. To make the serialization output more logical and
39 /// readable to human, we perform depth-first CFG traversal and delay the
40 /// serialization of the merge block and the continue block, if exists, until
41 /// after all other blocks have been processed.
42 static LogicalResult
44  function_ref<LogicalResult(Block *)> blockHandler,
45  bool skipHeader = false, BlockRange skipBlocks = {}) {
46  llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47  doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
48 
49  for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50  if (skipHeader && block == headerBlock)
51  continue;
52  if (failed(blockHandler(block)))
53  return failure();
54  }
55  return success();
56 }
57 
58 namespace mlir {
59 namespace spirv {
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61  if (auto resultID =
62  prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63  valueIDMap[op.getResult()] = resultID;
64  return success();
65  }
66  return failure();
67 }
68 
69 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
70  if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
71  /*isSpec=*/true)) {
72  // Emit the OpDecorate instruction for SpecId.
73  if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
74  auto val = static_cast<uint32_t>(specID.getInt());
75  if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
76  return failure();
77  }
78 
79  specConstIDMap[op.getSymName()] = resultID;
80  return processName(resultID, op.getSymName());
81  }
82  return failure();
83 }
84 
85 LogicalResult
86 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
87  uint32_t typeID = 0;
88  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
89  return failure();
90  }
91 
92  auto resultID = getNextID();
93 
94  SmallVector<uint32_t, 8> operands;
95  operands.push_back(typeID);
96  operands.push_back(resultID);
97 
98  auto constituents = op.getConstituents();
99 
100  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
101  auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
102 
103  auto constituentName = constituent.getValue();
104  auto constituentID = getSpecConstID(constituentName);
105 
106  if (!constituentID) {
107  return op.emitError("unknown result <id> for specialization constant ")
108  << constituentName;
109  }
110 
111  operands.push_back(constituentID);
112  }
113 
114  encodeInstructionInto(typesGlobalValues,
115  spirv::Opcode::OpSpecConstantComposite, operands);
116  specConstIDMap[op.getSymName()] = resultID;
117 
118  return processName(resultID, op.getSymName());
119 }
120 
121 LogicalResult
122 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
123  uint32_t typeID = 0;
124  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
125  return failure();
126  }
127 
128  auto resultID = getNextID();
129 
130  SmallVector<uint32_t, 8> operands;
131  operands.push_back(typeID);
132  operands.push_back(resultID);
133 
134  Block &block = op.getRegion().getBlocks().front();
135  Operation &enclosedOp = block.getOperations().front();
136 
137  std::string enclosedOpName;
138  llvm::raw_string_ostream rss(enclosedOpName);
139  rss << "Op" << enclosedOp.getName().stripDialect();
140  auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
141 
142  if (!enclosedOpcode) {
143  op.emitError("Couldn't find op code for op ")
144  << enclosedOp.getName().getStringRef();
145  return failure();
146  }
147 
148  operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
149 
150  // Append operands to the enclosed op to the list of operands.
151  for (Value operand : enclosedOp.getOperands()) {
152  uint32_t id = getValueID(operand);
153  assert(id && "use before def!");
154  operands.push_back(id);
155  }
156 
157  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
158  operands);
159  valueIDMap[op.getResult()] = resultID;
160 
161  return success();
162 }
163 
164 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
165  auto undefType = op.getType();
166  auto &id = undefValIDMap[undefType];
167  if (!id) {
168  id = getNextID();
169  uint32_t typeID = 0;
170  if (failed(processType(op.getLoc(), undefType, typeID)))
171  return failure();
172  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
173  {typeID, id});
174  }
175  valueIDMap[op.getResult()] = id;
176  return success();
177 }
178 
179 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
180  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
181  uint32_t argTypeID = 0;
182  if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
183  return failure();
184  }
185  auto argValueID = getNextID();
186 
187  // Process decoration attributes of arguments.
188  auto funcOp = cast<FunctionOpInterface>(*op);
189  for (auto argAttr : funcOp.getArgAttrs(idx)) {
190  if (argAttr.getName() != DecorationAttr::name)
191  continue;
192 
193  if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
194  if (failed(processDecorationAttr(op->getLoc(), argValueID,
195  decAttr.getValue(), decAttr)))
196  return failure();
197  }
198  }
199 
200  valueIDMap[arg] = argValueID;
201  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
202  {argTypeID, argValueID});
203  }
204  return success();
205 }
206 
207 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
208  LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
209  assert(functionHeader.empty() && functionBody.empty());
210 
211  uint32_t fnTypeID = 0;
212  // Generate type of the function.
213  if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
214  return failure();
215 
216  // Add the function definition.
217  SmallVector<uint32_t, 4> operands;
218  uint32_t resTypeID = 0;
219  auto resultTypes = op.getFunctionType().getResults();
220  if (resultTypes.size() > 1) {
221  return op.emitError("cannot serialize function with multiple return types");
222  }
223  if (failed(processType(op.getLoc(),
224  (resultTypes.empty() ? getVoidType() : resultTypes[0]),
225  resTypeID))) {
226  return failure();
227  }
228  operands.push_back(resTypeID);
229  auto funcID = getOrCreateFunctionID(op.getName());
230  operands.push_back(funcID);
231  operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
232  operands.push_back(fnTypeID);
233  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
234 
235  // Add function name.
236  if (failed(processName(funcID, op.getName()))) {
237  return failure();
238  }
239  // Handle external functions with linkage_attributes(LinkageAttributes)
240  // differently.
241  auto linkageAttr = op.getLinkageAttributes();
242  auto hasImportLinkage =
243  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
244  spirv::LinkageType::Import);
245  if (op.isExternal() && !hasImportLinkage) {
246  return op.emitError(
247  "'spirv.module' cannot contain external functions "
248  "without 'Import' linkage_attributes (LinkageAttributes)");
249  }
250  if (op.isExternal() && hasImportLinkage) {
251  // Add an entry block to set up the block arguments
252  // to match the signature of the function.
253  // This is to generate OpFunctionParameter for functions with
254  // LinkageAttributes.
255  // WARNING: This operation has side-effect, it essentially adds a body
256  // to the func. Hence, making it not external anymore (isExternal()
257  // is going to return false for this function from now on)
258  // Hence, we'll remove the body once we are done with the serialization.
259  op.addEntryBlock();
260  if (failed(processFuncParameter(op)))
261  return failure();
262  // Don't need to process the added block, there is nothing to process,
263  // the fake body was added just to get the arguments, remove the body,
264  // since it's use is done.
265  op.eraseBody();
266  } else {
267  if (failed(processFuncParameter(op)))
268  return failure();
269 
270  // Some instructions (e.g., OpVariable) in a function must be in the first
271  // block in the function. These instructions will be put in
272  // functionHeader. Thus, we put the label in functionHeader first, and
273  // omit it from the first block. OpLabel only needs to be added for
274  // functions with body (including empty body). Since, we added a fake body
275  // for functions with 'Import' Linkage attributes, these functions are
276  // essentially function delcaration, so they should not have OpLabel and a
277  // terminating instruction. That's why we skipped it for those functions.
278  encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
279  {getOrCreateBlockID(&op.front())});
280  if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
281  return failure();
282  if (failed(visitInPrettyBlockOrder(
283  &op.front(), [&](Block *block) { return processBlock(block); },
284  /*skipHeader=*/true))) {
285  return failure();
286  }
287 
288  // There might be OpPhi instructions who have value references needing to
289  // fix.
290  for (const auto &deferredValue : deferredPhiValues) {
291  Value value = deferredValue.first;
292  uint32_t id = getValueID(value);
293  LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
294  << " to id = " << id << '\n');
295  assert(id && "OpPhi references undefined value!");
296  for (size_t offset : deferredValue.second)
297  functionBody[offset] = id;
298  }
299  deferredPhiValues.clear();
300  }
301  LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
302  << "' --\n");
303  // Insert Decorations based on Function Attributes.
304  // Only attributes we should be considering for decoration are the
305  // ::mlir::spirv::Decoration attributes.
306 
307  for (auto attr : op->getAttrs()) {
308  // Only generate OpDecorate op for spirv::Decoration attributes.
309  auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
310  llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
311  /*capitalizeFirst=*/true));
312  if (isValidDecoration != std::nullopt) {
313  if (failed(processDecoration(op.getLoc(), funcID, attr))) {
314  return failure();
315  }
316  }
317  }
318  // Insert OpFunctionEnd.
319  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
320 
321  functions.append(functionHeader.begin(), functionHeader.end());
322  functions.append(functionBody.begin(), functionBody.end());
323  functionHeader.clear();
324  functionBody.clear();
325 
326  return success();
327 }
328 
329 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
330  SmallVector<uint32_t, 4> operands;
331  SmallVector<StringRef, 2> elidedAttrs;
332  uint32_t resultID = 0;
333  uint32_t resultTypeID = 0;
334  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
335  return failure();
336  }
337  operands.push_back(resultTypeID);
338  resultID = getNextID();
339  valueIDMap[op.getResult()] = resultID;
340  operands.push_back(resultID);
341  auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
342  if (attr) {
343  operands.push_back(
344  static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
345  }
346  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
347  for (auto arg : op.getODSOperands(0)) {
348  auto argID = getValueID(arg);
349  if (!argID) {
350  return emitError(op.getLoc(), "operand 0 has a use before def");
351  }
352  operands.push_back(argID);
353  }
354  if (failed(emitDebugLine(functionHeader, op.getLoc())))
355  return failure();
356  encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
357  for (auto attr : op->getAttrs()) {
358  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
359  return attr.getName() == elided;
360  })) {
361  continue;
362  }
363  if (failed(processDecoration(op.getLoc(), resultID, attr))) {
364  return failure();
365  }
366  }
367  return success();
368 }
369 
370 LogicalResult
371 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
372  // Get TypeID.
373  uint32_t resultTypeID = 0;
374  SmallVector<StringRef, 4> elidedAttrs;
375  if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
376  return failure();
377  }
378 
379  elidedAttrs.push_back("type");
380  SmallVector<uint32_t, 4> operands;
381  operands.push_back(resultTypeID);
382  auto resultID = getNextID();
383 
384  // Encode the name.
385  auto varName = varOp.getSymName();
386  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
387  if (failed(processName(resultID, varName))) {
388  return failure();
389  }
390  globalVarIDMap[varName] = resultID;
391  operands.push_back(resultID);
392 
393  // Encode StorageClass.
394  operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
395 
396  // Encode initialization.
397  StringRef initAttrName = varOp.getInitializerAttrName().getValue();
398  if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
399  uint32_t initializerID = 0;
400  auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
402  varOp->getParentOp(), initRef.getAttr());
403 
404  // Check if initializer is GlobalVariable or SpecConstant* cases.
405  if (isa<spirv::GlobalVariableOp>(initOp))
406  initializerID = getVariableID(*initSymbolName);
407  else
408  initializerID = getSpecConstID(*initSymbolName);
409 
410  if (!initializerID)
411  return emitError(varOp.getLoc(),
412  "invalid usage of undefined variable as initializer");
413 
414  operands.push_back(initializerID);
415  elidedAttrs.push_back(initAttrName);
416  }
417 
418  if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
419  return failure();
420  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
421  elidedAttrs.push_back(initAttrName);
422 
423  // Encode decorations.
424  for (auto attr : varOp->getAttrs()) {
425  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
426  return attr.getName() == elided;
427  })) {
428  continue;
429  }
430  if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
431  return failure();
432  }
433  }
434  return success();
435 }
436 
437 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
438  // Assign <id>s to all blocks so that branches inside the SelectionOp can
439  // resolve properly.
440  auto &body = selectionOp.getBody();
441  for (Block &block : body)
442  getOrCreateBlockID(&block);
443 
444  auto *headerBlock = selectionOp.getHeaderBlock();
445  auto *mergeBlock = selectionOp.getMergeBlock();
446  auto headerID = getBlockID(headerBlock);
447  auto mergeID = getBlockID(mergeBlock);
448  auto loc = selectionOp.getLoc();
449 
450  // Before we do anything replace results of the selection operation with
451  // values yielded (with `mlir.merge`) from inside the region. The selection op
452  // is being flattened so we do not have to worry about values being defined
453  // inside a region and used outside it anymore.
454  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
455  assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
456  for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
457  selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
458 
459  // This SelectionOp is in some MLIR block with preceding and following ops. In
460  // the binary format, it should reside in separate SPIR-V blocks from its
461  // preceding and following ops. So we need to emit unconditional branches to
462  // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
463  // flow afterwards.
464  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
465 
466  // Emit the selection header block, which dominates all other blocks, first.
467  // We need to emit an OpSelectionMerge instruction before the selection header
468  // block's terminator.
469  auto emitSelectionMerge = [&]() {
470  if (failed(emitDebugLine(functionBody, loc)))
471  return failure();
472  lastProcessedWasMergeInst = true;
474  functionBody, spirv::Opcode::OpSelectionMerge,
475  {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
476  return success();
477  };
478  if (failed(
479  processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
480  return failure();
481 
482  // Process all blocks with a depth-first visitor starting from the header
483  // block. The selection header block and merge block are skipped by this
484  // visitor.
485  if (failed(visitInPrettyBlockOrder(
486  headerBlock, [&](Block *block) { return processBlock(block); },
487  /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
488  return failure();
489 
490  // There is nothing to do for the merge block in the selection, which just
491  // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
492  // instruction to start a new SPIR-V block for ops following this SelectionOp.
493  // The block should use the <id> for the merge block.
494  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
495 
496  // We do not process the mergeBlock but we still need to generate phi
497  // functions from its block arguments.
498  if (failed(emitPhiForBlockArguments(mergeBlock)))
499  return failure();
500 
501  LLVM_DEBUG(llvm::dbgs() << "done merge ");
502  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
503  LLVM_DEBUG(llvm::dbgs() << "\n");
504  return success();
505 }
506 
507 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
508  // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
509  // properly. We don't need to assign for the entry block, which is just for
510  // satisfying MLIR region's structural requirement.
511  auto &body = loopOp.getBody();
512  for (Block &block : llvm::drop_begin(body))
513  getOrCreateBlockID(&block);
514 
515  auto *headerBlock = loopOp.getHeaderBlock();
516  auto *continueBlock = loopOp.getContinueBlock();
517  auto *mergeBlock = loopOp.getMergeBlock();
518  auto headerID = getBlockID(headerBlock);
519  auto continueID = getBlockID(continueBlock);
520  auto mergeID = getBlockID(mergeBlock);
521  auto loc = loopOp.getLoc();
522 
523  // This LoopOp is in some MLIR block with preceding and following ops. In the
524  // binary format, it should reside in separate SPIR-V blocks from its
525  // preceding and following ops. So we need to emit unconditional branches to
526  // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
527  // afterwards.
528  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
529 
530  // LoopOp's entry block is just there for satisfying MLIR's structural
531  // requirements so we omit it and start serialization from the loop header
532  // block.
533 
534  // Emit the loop header block, which dominates all other blocks, first. We
535  // need to emit an OpLoopMerge instruction before the loop header block's
536  // terminator.
537  auto emitLoopMerge = [&]() {
538  if (failed(emitDebugLine(functionBody, loc)))
539  return failure();
540  lastProcessedWasMergeInst = true;
542  functionBody, spirv::Opcode::OpLoopMerge,
543  {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
544  return success();
545  };
546  if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
547  return failure();
548 
549  // Process all blocks with a depth-first visitor starting from the header
550  // block. The loop header block, loop continue block, and loop merge block are
551  // skipped by this visitor and handled later in this function.
552  if (failed(visitInPrettyBlockOrder(
553  headerBlock, [&](Block *block) { return processBlock(block); },
554  /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
555  return failure();
556 
557  // We have handled all other blocks. Now get to the loop continue block.
558  if (failed(processBlock(continueBlock)))
559  return failure();
560 
561  // There is nothing to do for the merge block in the loop, which just contains
562  // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
563  // to start a new SPIR-V block for ops following this LoopOp. The block should
564  // use the <id> for the merge block.
565  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
566  LLVM_DEBUG(llvm::dbgs() << "done merge ");
567  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
568  LLVM_DEBUG(llvm::dbgs() << "\n");
569  return success();
570 }
571 
572 LogicalResult Serializer::processBranchConditionalOp(
573  spirv::BranchConditionalOp condBranchOp) {
574  auto conditionID = getValueID(condBranchOp.getCondition());
575  auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
576  auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
577  SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
578 
579  if (auto weights = condBranchOp.getBranchWeights()) {
580  for (auto val : weights->getValue())
581  arguments.push_back(cast<IntegerAttr>(val).getInt());
582  }
583 
584  if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
585  return failure();
586  encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
587  arguments);
588  return success();
589 }
590 
591 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
592  if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
593  return failure();
594  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
595  {getOrCreateBlockID(branchOp.getTarget())});
596  return success();
597 }
598 
599 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
600  auto varName = addressOfOp.getVariable();
601  auto variableID = getVariableID(varName);
602  if (!variableID) {
603  return addressOfOp.emitError("unknown result <id> for variable ")
604  << varName;
605  }
606  valueIDMap[addressOfOp.getPointer()] = variableID;
607  return success();
608 }
609 
610 LogicalResult
611 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
612  auto constName = referenceOfOp.getSpecConst();
613  auto constID = getSpecConstID(constName);
614  if (!constID) {
615  return referenceOfOp.emitError(
616  "unknown result <id> for specialization constant ")
617  << constName;
618  }
619  valueIDMap[referenceOfOp.getReference()] = constID;
620  return success();
621 }
622 
623 template <>
624 LogicalResult
625 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
626  SmallVector<uint32_t, 4> operands;
627  // Add the ExecutionModel.
628  operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
629  // Add the function <id>.
630  auto funcID = getFunctionID(op.getFn());
631  if (!funcID) {
632  return op.emitError("missing <id> for function ")
633  << op.getFn()
634  << "; function needs to be defined before spirv.EntryPoint is "
635  "serialized";
636  }
637  operands.push_back(funcID);
638  // Add the name of the function.
639  spirv::encodeStringLiteralInto(operands, op.getFn());
640 
641  // Add the interface values.
642  if (auto interface = op.getInterface()) {
643  for (auto var : interface.getValue()) {
644  auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
645  if (!id) {
646  return op.emitError(
647  "referencing undefined global variable."
648  "spirv.EntryPoint is at the end of spirv.module. All "
649  "referenced variables should already be defined");
650  }
651  operands.push_back(id);
652  }
653  }
654  encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
655  return success();
656 }
657 
658 template <>
659 LogicalResult
660 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
661  SmallVector<uint32_t, 4> operands;
662  // Add the function <id>.
663  auto funcID = getFunctionID(op.getFn());
664  if (!funcID) {
665  return op.emitError("missing <id> for function ")
666  << op.getFn()
667  << "; function needs to be serialized before ExecutionModeOp is "
668  "serialized";
669  }
670  operands.push_back(funcID);
671  // Add the ExecutionMode.
672  operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
673 
674  // Serialize values if any.
675  auto values = op.getValues();
676  if (values) {
677  for (auto &intVal : values.getValue()) {
678  operands.push_back(static_cast<uint32_t>(
679  llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
680  }
681  }
682  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
683  operands);
684  return success();
685 }
686 
687 template <>
688 LogicalResult
689 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
690  auto funcName = op.getCallee();
691  uint32_t resTypeID = 0;
692 
693  Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
694  if (failed(processType(op.getLoc(), resultTy, resTypeID)))
695  return failure();
696 
697  auto funcID = getOrCreateFunctionID(funcName);
698  auto funcCallID = getNextID();
699  SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
700 
701  for (auto value : op.getArguments()) {
702  auto valueID = getValueID(value);
703  assert(valueID && "cannot find a value for spirv.FunctionCall");
704  operands.push_back(valueID);
705  }
706 
707  if (!isa<NoneType>(resultTy))
708  valueIDMap[op.getResult(0)] = funcCallID;
709 
710  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
711  return success();
712 }
713 
714 template <>
715 LogicalResult
716 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
717  SmallVector<uint32_t, 4> operands;
718  SmallVector<StringRef, 2> elidedAttrs;
719 
720  for (Value operand : op->getOperands()) {
721  auto id = getValueID(operand);
722  assert(id && "use before def!");
723  operands.push_back(id);
724  }
725 
726  StringAttr memoryAccess = op.getMemoryAccessAttrName();
727  if (auto attr = op->getAttr(memoryAccess)) {
728  operands.push_back(
729  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
730  }
731 
732  elidedAttrs.push_back(memoryAccess.strref());
733 
734  StringAttr alignment = op.getAlignmentAttrName();
735  if (auto attr = op->getAttr(alignment)) {
736  operands.push_back(static_cast<uint32_t>(
737  cast<IntegerAttr>(attr).getValue().getZExtValue()));
738  }
739 
740  elidedAttrs.push_back(alignment.strref());
741 
742  StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
743  if (auto attr = op->getAttr(sourceMemoryAccess)) {
744  operands.push_back(
745  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
746  }
747 
748  elidedAttrs.push_back(sourceMemoryAccess.strref());
749 
750  StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
751  if (auto attr = op->getAttr(sourceAlignment)) {
752  operands.push_back(static_cast<uint32_t>(
753  cast<IntegerAttr>(attr).getValue().getZExtValue()));
754  }
755 
756  elidedAttrs.push_back(sourceAlignment.strref());
757  if (failed(emitDebugLine(functionBody, op.getLoc())))
758  return failure();
759  encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
760 
761  return success();
762 }
763 template <>
764 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
765  spirv::GenericCastToPtrExplicitOp op) {
766  SmallVector<uint32_t, 4> operands;
767  Type resultTy;
768  Location loc = op->getLoc();
769  uint32_t resultTypeID = 0;
770  uint32_t resultID = 0;
771  resultTy = op->getResult(0).getType();
772  if (failed(processType(loc, resultTy, resultTypeID)))
773  return failure();
774  operands.push_back(resultTypeID);
775 
776  resultID = getNextID();
777  operands.push_back(resultID);
778  valueIDMap[op->getResult(0)] = resultID;
779 
780  for (Value operand : op->getOperands())
781  operands.push_back(getValueID(operand));
782  spirv::StorageClass resultStorage =
783  cast<spirv::PointerType>(resultTy).getStorageClass();
784  operands.push_back(static_cast<uint32_t>(resultStorage));
785  encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
786  operands);
787  return success();
788 }
789 
790 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
791 // various Serializer::processOp<...>() specializations.
792 #define GET_SERIALIZATION_FNS
793 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
794 
795 } // namespace spirv
796 } // namespace mlir
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:78
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.