MLIR  22.0.0git
SerializeOps.cpp
Go to the documentation of this file.
1 //===- SerializeOps.cpp - MLIR SPIR-V Serialization (Ops) -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the serialization methods for MLIR SPIR-V module ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Serializer.h"
14 
19 #include "llvm/ADT/DepthFirstIterator.h"
20 #include "llvm/ADT/StringExtras.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "spirv-serialization"
24 
25 using namespace mlir;
26 
27 /// A pre-order depth-first visitor function for processing basic blocks.
28 ///
29 /// Visits the basic blocks starting from the given `headerBlock` in pre-order
30 /// depth-first manner and calls `blockHandler` on each block. Skips handling
31 /// blocks in the `skipBlocks` list. If `skipHeader` is true, `blockHandler`
32 /// will not be invoked in `headerBlock` but still handles all `headerBlock`'s
33 /// successors.
34 ///
35 /// SPIR-V spec "2.16.1. Universal Validation Rules" requires that "the order
36 /// of blocks in a function must satisfy the rule that blocks appear before
37 /// all blocks they dominate." This can be achieved by a pre-order CFG
38 /// traversal algorithm. To make the serialization output more logical and
39 /// readable to human, we perform depth-first CFG traversal and delay the
40 /// serialization of the merge block and the continue block, if exists, until
41 /// after all other blocks have been processed.
42 static LogicalResult
44  function_ref<LogicalResult(Block *)> blockHandler,
45  bool skipHeader = false, BlockRange skipBlocks = {}) {
46  llvm::df_iterator_default_set<Block *, 4> doneBlocks;
47  doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
48 
49  for (Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
50  if (skipHeader && block == headerBlock)
51  continue;
52  if (failed(blockHandler(block)))
53  return failure();
54  }
55  return success();
56 }
57 
58 namespace mlir {
59 namespace spirv {
60 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
61  if (auto resultID =
62  prepareConstant(op.getLoc(), op.getType(), op.getValue())) {
63  valueIDMap[op.getResult()] = resultID;
64  return success();
65  }
66  return failure();
67 }
68 
69 LogicalResult Serializer::processConstantCompositeReplicateOp(
70  spirv::EXTConstantCompositeReplicateOp op) {
71  if (uint32_t resultID = prepareConstantCompositeReplicate(
72  op.getLoc(), op.getType(), op.getValue())) {
73  valueIDMap[op.getResult()] = resultID;
74  return success();
75  }
76  return failure();
77 }
78 
79 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
80  if (auto resultID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
81  /*isSpec=*/true)) {
82  // Emit the OpDecorate instruction for SpecId.
83  if (auto specID = op->getAttrOfType<IntegerAttr>("spec_id")) {
84  auto val = static_cast<uint32_t>(specID.getInt());
85  if (failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
86  return failure();
87  }
88 
89  specConstIDMap[op.getSymName()] = resultID;
90  return processName(resultID, op.getSymName());
91  }
92  return failure();
93 }
94 
95 LogicalResult
96 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
97  uint32_t typeID = 0;
98  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
99  return failure();
100  }
101 
102  auto resultID = getNextID();
103 
104  SmallVector<uint32_t, 8> operands;
105  operands.push_back(typeID);
106  operands.push_back(resultID);
107 
108  auto constituents = op.getConstituents();
109 
110  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
111  auto constituent = dyn_cast<FlatSymbolRefAttr>(constituents[index]);
112 
113  auto constituentName = constituent.getValue();
114  auto constituentID = getSpecConstID(constituentName);
115 
116  if (!constituentID) {
117  return op.emitError("unknown result <id> for specialization constant ")
118  << constituentName;
119  }
120 
121  operands.push_back(constituentID);
122  }
123 
124  encodeInstructionInto(typesGlobalValues,
125  spirv::Opcode::OpSpecConstantComposite, operands);
126  specConstIDMap[op.getSymName()] = resultID;
127 
128  return processName(resultID, op.getSymName());
129 }
130 
131 LogicalResult Serializer::processSpecConstantCompositeReplicateOp(
132  spirv::EXTSpecConstantCompositeReplicateOp op) {
133  uint32_t typeID = 0;
134  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
135  return failure();
136  }
137 
138  auto constituent = dyn_cast<FlatSymbolRefAttr>(op.getConstituent());
139  if (!constituent)
140  return op.emitError(
141  "expected flat symbol reference for constituent instead of ")
142  << op.getConstituent();
143 
144  StringRef constituentName = constituent.getValue();
145  uint32_t constituentID = getSpecConstID(constituentName);
146  if (!constituentID) {
147  return op.emitError("unknown result <id> for replicated spec constant ")
148  << constituentName;
149  }
150 
151  uint32_t resultID = getNextID();
152  uint32_t operands[] = {typeID, resultID, constituentID};
153 
154  encodeInstructionInto(typesGlobalValues,
155  spirv::Opcode::OpSpecConstantCompositeReplicateEXT,
156  operands);
157 
158  specConstIDMap[op.getSymName()] = resultID;
159 
160  return processName(resultID, op.getSymName());
161 }
162 
163 LogicalResult
164 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
165  uint32_t typeID = 0;
166  if (failed(processType(op.getLoc(), op.getType(), typeID))) {
167  return failure();
168  }
169 
170  auto resultID = getNextID();
171 
172  SmallVector<uint32_t, 8> operands;
173  operands.push_back(typeID);
174  operands.push_back(resultID);
175 
176  Block &block = op.getRegion().getBlocks().front();
177  Operation &enclosedOp = block.getOperations().front();
178 
179  std::string enclosedOpName;
180  llvm::raw_string_ostream rss(enclosedOpName);
181  rss << "Op" << enclosedOp.getName().stripDialect();
182  auto enclosedOpcode = spirv::symbolizeOpcode(enclosedOpName);
183 
184  if (!enclosedOpcode) {
185  op.emitError("Couldn't find op code for op ")
186  << enclosedOp.getName().getStringRef();
187  return failure();
188  }
189 
190  operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
191 
192  // Append operands to the enclosed op to the list of operands.
193  for (Value operand : enclosedOp.getOperands()) {
194  uint32_t id = getValueID(operand);
195  assert(id && "use before def!");
196  operands.push_back(id);
197  }
198 
199  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
200  operands);
201  valueIDMap[op.getResult()] = resultID;
202 
203  return success();
204 }
205 
206 LogicalResult
207 Serializer::processGraphConstantARMOp(spirv::GraphConstantARMOp op) {
208  if (uint32_t resultID = prepareGraphConstantId(op.getLoc(), op.getType(),
209  op.getGraphConstantIdAttr())) {
210  valueIDMap[op.getResult()] = resultID;
211  return success();
212  }
213  return failure();
214 }
215 
216 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
217  auto undefType = op.getType();
218  auto &id = undefValIDMap[undefType];
219  if (!id) {
220  id = getNextID();
221  uint32_t typeID = 0;
222  if (failed(processType(op.getLoc(), undefType, typeID)))
223  return failure();
224  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
225  {typeID, id});
226  }
227  valueIDMap[op.getResult()] = id;
228  return success();
229 }
230 
231 LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
232  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
233  uint32_t argTypeID = 0;
234  if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
235  return failure();
236  }
237  auto argValueID = getNextID();
238 
239  // Process decoration attributes of arguments.
240  auto funcOp = cast<FunctionOpInterface>(*op);
241  for (auto argAttr : funcOp.getArgAttrs(idx)) {
242  if (argAttr.getName() != DecorationAttr::name)
243  continue;
244 
245  if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
246  if (failed(processDecorationAttr(op->getLoc(), argValueID,
247  decAttr.getValue(), decAttr)))
248  return failure();
249  }
250  }
251 
252  valueIDMap[arg] = argValueID;
253  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
254  {argTypeID, argValueID});
255  }
256  return success();
257 }
258 
259 LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
260  LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
261  assert(functionHeader.empty() && functionBody.empty());
262 
263  uint32_t fnTypeID = 0;
264  // Generate type of the function.
265  if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
266  return failure();
267 
268  // Add the function definition.
269  SmallVector<uint32_t, 4> operands;
270  uint32_t resTypeID = 0;
271  auto resultTypes = op.getFunctionType().getResults();
272  if (resultTypes.size() > 1) {
273  return op.emitError("cannot serialize function with multiple return types");
274  }
275  if (failed(processType(op.getLoc(),
276  (resultTypes.empty() ? getVoidType() : resultTypes[0]),
277  resTypeID))) {
278  return failure();
279  }
280  operands.push_back(resTypeID);
281  auto funcID = getOrCreateFunctionID(op.getName());
282  operands.push_back(funcID);
283  operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
284  operands.push_back(fnTypeID);
285  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);
286 
287  // Add function name.
288  if (failed(processName(funcID, op.getName()))) {
289  return failure();
290  }
291  // Handle external functions with linkage_attributes(LinkageAttributes)
292  // differently.
293  auto linkageAttr = op.getLinkageAttributes();
294  auto hasImportLinkage =
295  linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
296  spirv::LinkageType::Import);
297  if (op.isExternal() && !hasImportLinkage) {
298  return op.emitError(
299  "'spirv.module' cannot contain external functions "
300  "without 'Import' linkage_attributes (LinkageAttributes)");
301  }
302  if (op.isExternal() && hasImportLinkage) {
303  // Add an entry block to set up the block arguments
304  // to match the signature of the function.
305  // This is to generate OpFunctionParameter for functions with
306  // LinkageAttributes.
307  // WARNING: This operation has side-effect, it essentially adds a body
308  // to the func. Hence, making it not external anymore (isExternal()
309  // is going to return false for this function from now on)
310  // Hence, we'll remove the body once we are done with the serialization.
311  op.addEntryBlock();
312  if (failed(processFuncParameter(op)))
313  return failure();
314  // Don't need to process the added block, there is nothing to process,
315  // the fake body was added just to get the arguments, remove the body,
316  // since it's use is done.
317  op.eraseBody();
318  } else {
319  if (failed(processFuncParameter(op)))
320  return failure();
321 
322  // Some instructions (e.g., OpVariable) in a function must be in the first
323  // block in the function. These instructions will be put in
324  // functionHeader. Thus, we put the label in functionHeader first, and
325  // omit it from the first block. OpLabel only needs to be added for
326  // functions with body (including empty body). Since, we added a fake body
327  // for functions with 'Import' Linkage attributes, these functions are
328  // essentially function delcaration, so they should not have OpLabel and a
329  // terminating instruction. That's why we skipped it for those functions.
330  encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
331  {getOrCreateBlockID(&op.front())});
332  if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
333  return failure();
335  &op.front(), [&](Block *block) { return processBlock(block); },
336  /*skipHeader=*/true))) {
337  return failure();
338  }
339 
340  // There might be OpPhi instructions who have value references needing to
341  // fix.
342  for (const auto &deferredValue : deferredPhiValues) {
343  Value value = deferredValue.first;
344  uint32_t id = getValueID(value);
345  LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
346  << " to id = " << id << '\n');
347  assert(id && "OpPhi references undefined value!");
348  for (size_t offset : deferredValue.second)
349  functionBody[offset] = id;
350  }
351  deferredPhiValues.clear();
352  }
353  LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
354  << "' --\n");
355  // Insert Decorations based on Function Attributes.
356  // Only attributes we should be considering for decoration are the
357  // ::mlir::spirv::Decoration attributes.
358 
359  for (auto attr : op->getAttrs()) {
360  // Only generate OpDecorate op for spirv::Decoration attributes.
361  auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
362  llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
363  /*capitalizeFirst=*/true));
364  if (isValidDecoration != std::nullopt) {
365  if (failed(processDecoration(op.getLoc(), funcID, attr))) {
366  return failure();
367  }
368  }
369  }
370  // Insert OpFunctionEnd.
371  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
372 
373  functions.append(functionHeader.begin(), functionHeader.end());
374  functions.append(functionBody.begin(), functionBody.end());
375  functionHeader.clear();
376  functionBody.clear();
377 
378  return success();
379 }
380 
381 LogicalResult Serializer::processGraphARMOp(spirv::GraphARMOp op) {
382  if (op.getNumResults() < 1) {
383  return op.emitError("cannot serialize graph with no return types");
384  }
385 
386  LLVM_DEBUG(llvm::dbgs() << "-- start graph '" << op.getName() << "' --\n");
387  assert(functionHeader.empty() && functionBody.empty());
388 
389  uint32_t funcID = getOrCreateFunctionID(op.getName());
390  uint32_t fnTypeID = 0;
391  // Generate type of the function.
392  if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
393  return failure();
394  encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphARM,
395  {fnTypeID, funcID});
396 
397  // Declare the parameters.
398  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
399  uint32_t argTypeID = 0;
400  SmallVector<uint32_t, 3> inputOperands;
401 
402  if (failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
403  return failure();
404  }
405 
406  uint32_t argValueID = getNextID();
407  valueIDMap[arg] = argValueID;
408 
409  auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
410  uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false);
411 
412  inputOperands.push_back(argTypeID);
413  inputOperands.push_back(argValueID);
414  inputOperands.push_back(indexID);
415 
416  encodeInstructionInto(functionHeader, spirv::Opcode::OpGraphInputARM,
417  inputOperands);
418  }
419 
420  if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
421  return failure();
423  &op.front(), [&](Block *block) { return processBlock(block); },
424  /*skipHeader=*/true))) {
425  return failure();
426  }
427 
428  LLVM_DEBUG(llvm::dbgs() << "-- completed graph '" << op.getName()
429  << "' --\n");
430  // Insert OpGraphEndARM.
431  encodeInstructionInto(functionBody, spirv::Opcode::OpGraphEndARM, {});
432 
433  llvm::append_range(graphs, functionHeader);
434  llvm::append_range(graphs, functionBody);
435  functionHeader.clear();
436  functionBody.clear();
437 
438  return success();
439 }
440 
441 LogicalResult
442 Serializer::processGraphEntryPointARMOp(spirv::GraphEntryPointARMOp op) {
443  SmallVector<uint32_t, 4> operands;
444  StringRef graph = op.getFn();
445  // Add the graph <id>.
446  uint32_t graphID = getOrCreateFunctionID(graph);
447  operands.push_back(graphID);
448  // Add the name of the graph.
449  spirv::encodeStringLiteralInto(operands, graph);
450 
451  // Add the interface values.
452  if (ArrayAttr interface = op.getInterface()) {
453  for (Attribute var : interface.getValue()) {
454  StringRef value = cast<FlatSymbolRefAttr>(var).getValue();
455  if (uint32_t id = getVariableID(value)) {
456  operands.push_back(id);
457  } else {
458  return op.emitError(
459  "referencing undefined global variable."
460  "spirv.GraphEntryPointARM is at the end of spirv.module. All "
461  "referenced variables should already be defined");
462  }
463  }
464  }
465  encodeInstructionInto(graphs, spirv::Opcode::OpGraphEntryPointARM, operands);
466  return success();
467 }
468 
469 LogicalResult
470 Serializer::processGraphOutputsARMOp(spirv::GraphOutputsARMOp op) {
471  for (auto [idx, value] : llvm::enumerate(op->getOperands())) {
472  SmallVector<uint32_t, 2> outputOperands;
473 
474  Type resType = value.getType();
475  uint32_t resTypeID = 0;
476  if (failed(processType(op.getLoc(), resType, resTypeID))) {
477  return failure();
478  }
479 
480  uint32_t outputID = getValueID(value);
481  auto attr = IntegerAttr::get(IntegerType::get(op.getContext(), 32), idx);
482  uint32_t indexID = prepareConstantInt(op.getLoc(), attr, false);
483 
484  outputOperands.push_back(outputID);
485  outputOperands.push_back(indexID);
486 
487  encodeInstructionInto(functionBody, spirv::Opcode::OpGraphSetOutputARM,
488  outputOperands);
489  }
490  return success();
491 }
492 
493 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
494  SmallVector<uint32_t, 4> operands;
495  SmallVector<StringRef, 2> elidedAttrs;
496  uint32_t resultID = 0;
497  uint32_t resultTypeID = 0;
498  if (failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
499  return failure();
500  }
501  operands.push_back(resultTypeID);
502  resultID = getNextID();
503  valueIDMap[op.getResult()] = resultID;
504  operands.push_back(resultID);
505  auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
506  if (attr) {
507  operands.push_back(
508  static_cast<uint32_t>(cast<spirv::StorageClassAttr>(attr).getValue()));
509  }
510  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
511  for (auto arg : op.getODSOperands(0)) {
512  auto argID = getValueID(arg);
513  if (!argID) {
514  return emitError(op.getLoc(), "operand 0 has a use before def");
515  }
516  operands.push_back(argID);
517  }
518  if (failed(emitDebugLine(functionHeader, op.getLoc())))
519  return failure();
520  encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);
521  for (auto attr : op->getAttrs()) {
522  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
523  return attr.getName() == elided;
524  })) {
525  continue;
526  }
527  if (failed(processDecoration(op.getLoc(), resultID, attr))) {
528  return failure();
529  }
530  }
531  return success();
532 }
533 
534 LogicalResult
535 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
536  // Get TypeID.
537  uint32_t resultTypeID = 0;
538  SmallVector<StringRef, 4> elidedAttrs;
539  if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID))) {
540  return failure();
541  }
542 
543  elidedAttrs.push_back("type");
544  SmallVector<uint32_t, 4> operands;
545  operands.push_back(resultTypeID);
546  auto resultID = getNextID();
547 
548  // Encode the name.
549  auto varName = varOp.getSymName();
550  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
551  if (failed(processName(resultID, varName))) {
552  return failure();
553  }
554  globalVarIDMap[varName] = resultID;
555  operands.push_back(resultID);
556 
557  // Encode StorageClass.
558  operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
559 
560  // Encode initialization.
561  StringRef initAttrName = varOp.getInitializerAttrName().getValue();
562  if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
563  uint32_t initializerID = 0;
564  auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
566  varOp->getParentOp(), initRef.getAttr());
567 
568  // Check if initializer is GlobalVariable or SpecConstant* cases.
569  if (isa<spirv::GlobalVariableOp>(initOp))
570  initializerID = getVariableID(*initSymbolName);
571  else
572  initializerID = getSpecConstID(*initSymbolName);
573 
574  if (!initializerID)
575  return emitError(varOp.getLoc(),
576  "invalid usage of undefined variable as initializer");
577 
578  operands.push_back(initializerID);
579  elidedAttrs.push_back(initAttrName);
580  }
581 
582  if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
583  return failure();
584  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
585  elidedAttrs.push_back(initAttrName);
586 
587  // Encode decorations.
588  for (auto attr : varOp->getAttrs()) {
589  if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
590  return attr.getName() == elided;
591  })) {
592  continue;
593  }
594  if (failed(processDecoration(varOp.getLoc(), resultID, attr))) {
595  return failure();
596  }
597  }
598  return success();
599 }
600 
601 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
602  // Assign <id>s to all blocks so that branches inside the SelectionOp can
603  // resolve properly.
604  auto &body = selectionOp.getBody();
605  for (Block &block : body)
606  getOrCreateBlockID(&block);
607 
608  auto *headerBlock = selectionOp.getHeaderBlock();
609  auto *mergeBlock = selectionOp.getMergeBlock();
610  auto headerID = getBlockID(headerBlock);
611  auto mergeID = getBlockID(mergeBlock);
612  auto loc = selectionOp.getLoc();
613 
614  // Before we do anything replace results of the selection operation with
615  // values yielded (with `mlir.merge`) from inside the region. The selection op
616  // is being flattened so we do not have to worry about values being defined
617  // inside a region and used outside it anymore.
618  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
619  assert(selectionOp.getNumResults() == mergeOp.getNumOperands());
620  for (unsigned i = 0, e = selectionOp.getNumResults(); i != e; ++i)
621  selectionOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
622 
623  // This SelectionOp is in some MLIR block with preceding and following ops. In
624  // the binary format, it should reside in separate SPIR-V blocks from its
625  // preceding and following ops. So we need to emit unconditional branches to
626  // jump to this SelectionOp's SPIR-V blocks and jumping back to the normal
627  // flow afterwards.
628  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
629 
630  // Emit the selection header block, which dominates all other blocks, first.
631  // We need to emit an OpSelectionMerge instruction before the selection header
632  // block's terminator.
633  auto emitSelectionMerge = [&]() {
634  if (failed(emitDebugLine(functionBody, loc)))
635  return failure();
636  lastProcessedWasMergeInst = true;
638  functionBody, spirv::Opcode::OpSelectionMerge,
639  {mergeID, static_cast<uint32_t>(selectionOp.getSelectionControl())});
640  return success();
641  };
642  if (failed(
643  processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge)))
644  return failure();
645 
646  // Process all blocks with a depth-first visitor starting from the header
647  // block. The selection header block and merge block are skipped by this
648  // visitor.
650  headerBlock, [&](Block *block) { return processBlock(block); },
651  /*skipHeader=*/true, /*skipBlocks=*/{mergeBlock})))
652  return failure();
653 
654  // There is nothing to do for the merge block in the selection, which just
655  // contains a spirv.mlir.merge op, itself. But we need to have an OpLabel
656  // instruction to start a new SPIR-V block for ops following this SelectionOp.
657  // The block should use the <id> for the merge block.
658  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
659 
660  // We do not process the mergeBlock but we still need to generate phi
661  // functions from its block arguments.
662  if (failed(emitPhiForBlockArguments(mergeBlock)))
663  return failure();
664 
665  LLVM_DEBUG(llvm::dbgs() << "done merge ");
666  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
667  LLVM_DEBUG(llvm::dbgs() << "\n");
668  return success();
669 }
670 
671 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
672  // Assign <id>s to all blocks so that branches inside the LoopOp can resolve
673  // properly. We don't need to assign for the entry block, which is just for
674  // satisfying MLIR region's structural requirement.
675  auto &body = loopOp.getBody();
676  for (Block &block : llvm::drop_begin(body))
677  getOrCreateBlockID(&block);
678 
679  auto *headerBlock = loopOp.getHeaderBlock();
680  auto *continueBlock = loopOp.getContinueBlock();
681  auto *mergeBlock = loopOp.getMergeBlock();
682  auto headerID = getBlockID(headerBlock);
683  auto continueID = getBlockID(continueBlock);
684  auto mergeID = getBlockID(mergeBlock);
685  auto loc = loopOp.getLoc();
686 
687  // Before we do anything replace results of the selection operation with
688  // values yielded (with `mlir.merge`) from inside the region.
689  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
690  assert(loopOp.getNumResults() == mergeOp.getNumOperands());
691  for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
692  loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
693 
694  // This LoopOp is in some MLIR block with preceding and following ops. In the
695  // binary format, it should reside in separate SPIR-V blocks from its
696  // preceding and following ops. So we need to emit unconditional branches to
697  // jump to this LoopOp's SPIR-V blocks and jumping back to the normal flow
698  // afterwards.
699  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
700 
701  // LoopOp's entry block is just there for satisfying MLIR's structural
702  // requirements so we omit it and start serialization from the loop header
703  // block.
704 
705  // Emit the loop header block, which dominates all other blocks, first. We
706  // need to emit an OpLoopMerge instruction before the loop header block's
707  // terminator.
708  auto emitLoopMerge = [&]() {
709  if (failed(emitDebugLine(functionBody, loc)))
710  return failure();
711  lastProcessedWasMergeInst = true;
713  functionBody, spirv::Opcode::OpLoopMerge,
714  {mergeID, continueID, static_cast<uint32_t>(loopOp.getLoopControl())});
715  return success();
716  };
717  if (failed(processBlock(headerBlock, /*omitLabel=*/false, emitLoopMerge)))
718  return failure();
719 
720  // Process all blocks with a depth-first visitor starting from the header
721  // block. The loop header block, loop continue block, and loop merge block are
722  // skipped by this visitor and handled later in this function.
724  headerBlock, [&](Block *block) { return processBlock(block); },
725  /*skipHeader=*/true, /*skipBlocks=*/{continueBlock, mergeBlock})))
726  return failure();
727 
728  // We have handled all other blocks. Now get to the loop continue block.
729  if (failed(processBlock(continueBlock)))
730  return failure();
731 
732  // There is nothing to do for the merge block in the loop, which just contains
733  // a spirv.mlir.merge op, itself. But we need to have an OpLabel instruction
734  // to start a new SPIR-V block for ops following this LoopOp. The block should
735  // use the <id> for the merge block.
736  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID});
737  LLVM_DEBUG(llvm::dbgs() << "done merge ");
738  LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
739  LLVM_DEBUG(llvm::dbgs() << "\n");
740  return success();
741 }
742 
743 LogicalResult Serializer::processBranchConditionalOp(
744  spirv::BranchConditionalOp condBranchOp) {
745  auto conditionID = getValueID(condBranchOp.getCondition());
746  auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
747  auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
748  SmallVector<uint32_t, 5> arguments{conditionID, trueLabelID, falseLabelID};
749 
750  if (auto weights = condBranchOp.getBranchWeights()) {
751  for (auto val : weights->getValue())
752  arguments.push_back(cast<IntegerAttr>(val).getInt());
753  }
754 
755  if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
756  return failure();
757  encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
758  arguments);
759  return success();
760 }
761 
762 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
763  if (failed(emitDebugLine(functionBody, branchOp.getLoc())))
764  return failure();
765  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch,
766  {getOrCreateBlockID(branchOp.getTarget())});
767  return success();
768 }
769 
770 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
771  auto varName = addressOfOp.getVariable();
772  auto variableID = getVariableID(varName);
773  if (!variableID) {
774  return addressOfOp.emitError("unknown result <id> for variable ")
775  << varName;
776  }
777  valueIDMap[addressOfOp.getPointer()] = variableID;
778  return success();
779 }
780 
781 LogicalResult
782 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
783  auto constName = referenceOfOp.getSpecConst();
784  auto constID = getSpecConstID(constName);
785  if (!constID) {
786  return referenceOfOp.emitError(
787  "unknown result <id> for specialization constant ")
788  << constName;
789  }
790  valueIDMap[referenceOfOp.getReference()] = constID;
791  return success();
792 }
793 
794 template <>
795 LogicalResult
796 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
797  SmallVector<uint32_t, 4> operands;
798  // Add the ExecutionModel.
799  operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));
800  // Add the function <id>.
801  auto funcID = getFunctionID(op.getFn());
802  if (!funcID) {
803  return op.emitError("missing <id> for function ")
804  << op.getFn()
805  << "; function needs to be defined before spirv.EntryPoint is "
806  "serialized";
807  }
808  operands.push_back(funcID);
809  // Add the name of the function.
810  spirv::encodeStringLiteralInto(operands, op.getFn());
811 
812  // Add the interface values.
813  if (auto interface = op.getInterface()) {
814  for (auto var : interface.getValue()) {
815  auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
816  if (!id) {
817  return op.emitError(
818  "referencing undefined global variable."
819  "spirv.EntryPoint is at the end of spirv.module. All "
820  "referenced variables should already be defined");
821  }
822  operands.push_back(id);
823  }
824  }
825  encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
826  return success();
827 }
828 
829 template <>
830 LogicalResult
831 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
832  SmallVector<uint32_t, 4> operands;
833  // Add the function <id>.
834  auto funcID = getFunctionID(op.getFn());
835  if (!funcID) {
836  return op.emitError("missing <id> for function ")
837  << op.getFn()
838  << "; function needs to be serialized before ExecutionModeOp is "
839  "serialized";
840  }
841  operands.push_back(funcID);
842  // Add the ExecutionMode.
843  operands.push_back(static_cast<uint32_t>(op.getExecutionMode()));
844 
845  // Serialize values if any.
846  auto values = op.getValues();
847  if (values) {
848  for (auto &intVal : values.getValue()) {
849  operands.push_back(static_cast<uint32_t>(
850  llvm::cast<IntegerAttr>(intVal).getValue().getZExtValue()));
851  }
852  }
853  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
854  operands);
855  return success();
856 }
857 
858 template <>
859 LogicalResult
860 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
861  auto funcName = op.getCallee();
862  uint32_t resTypeID = 0;
863 
864  Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
865  if (failed(processType(op.getLoc(), resultTy, resTypeID)))
866  return failure();
867 
868  auto funcID = getOrCreateFunctionID(funcName);
869  auto funcCallID = getNextID();
870  SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
871 
872  for (auto value : op.getArguments()) {
873  auto valueID = getValueID(value);
874  assert(valueID && "cannot find a value for spirv.FunctionCall");
875  operands.push_back(valueID);
876  }
877 
878  if (!isa<NoneType>(resultTy))
879  valueIDMap[op.getResult(0)] = funcCallID;
880 
881  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
882  return success();
883 }
884 
885 template <>
886 LogicalResult
887 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
888  SmallVector<uint32_t, 4> operands;
889  SmallVector<StringRef, 2> elidedAttrs;
890 
891  for (Value operand : op->getOperands()) {
892  auto id = getValueID(operand);
893  assert(id && "use before def!");
894  operands.push_back(id);
895  }
896 
897  StringAttr memoryAccess = op.getMemoryAccessAttrName();
898  if (auto attr = op->getAttr(memoryAccess)) {
899  operands.push_back(
900  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
901  }
902 
903  elidedAttrs.push_back(memoryAccess.strref());
904 
905  StringAttr alignment = op.getAlignmentAttrName();
906  if (auto attr = op->getAttr(alignment)) {
907  operands.push_back(static_cast<uint32_t>(
908  cast<IntegerAttr>(attr).getValue().getZExtValue()));
909  }
910 
911  elidedAttrs.push_back(alignment.strref());
912 
913  StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
914  if (auto attr = op->getAttr(sourceMemoryAccess)) {
915  operands.push_back(
916  static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
917  }
918 
919  elidedAttrs.push_back(sourceMemoryAccess.strref());
920 
921  StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
922  if (auto attr = op->getAttr(sourceAlignment)) {
923  operands.push_back(static_cast<uint32_t>(
924  cast<IntegerAttr>(attr).getValue().getZExtValue()));
925  }
926 
927  elidedAttrs.push_back(sourceAlignment.strref());
928  if (failed(emitDebugLine(functionBody, op.getLoc())))
929  return failure();
930  encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
931 
932  return success();
933 }
934 template <>
935 LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
936  spirv::GenericCastToPtrExplicitOp op) {
937  SmallVector<uint32_t, 4> operands;
938  Type resultTy;
939  Location loc = op->getLoc();
940  uint32_t resultTypeID = 0;
941  uint32_t resultID = 0;
942  resultTy = op->getResult(0).getType();
943  if (failed(processType(loc, resultTy, resultTypeID)))
944  return failure();
945  operands.push_back(resultTypeID);
946 
947  resultID = getNextID();
948  operands.push_back(resultID);
949  valueIDMap[op->getResult(0)] = resultID;
950 
951  for (Value operand : op->getOperands())
952  operands.push_back(getValueID(operand));
953  spirv::StorageClass resultStorage =
954  cast<spirv::PointerType>(resultTy).getStorageClass();
955  operands.push_back(static_cast<uint32_t>(resultStorage));
956  encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
957  operands);
958  return success();
959 }
960 
961 // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
962 // various Serializer::processOp<...>() specializations.
963 #define GET_SERIALIZATION_FNS
964 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
965 
966 } // namespace spirv
967 } // namespace mlir
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
OpListType & getOperations()
Definition: Block.h:137
Operation & front()
Definition: Block.h:153
A symbol reference with a reference path containing a single element.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes an SPIR-V literal string into the given binary vector.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes an SPIR-V instruction with the given opcode and operands into the given binary vector.
Definition: Serializer.cpp:97
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...