19 #include "llvm/ADT/DepthFirstIterator.h" 20 #include "llvm/Support/Debug.h" 22 #define DEBUG_TYPE "spirv-serialization" 44 bool skipHeader =
false,
BlockRange skipBlocks = {}) {
45 llvm::df_iterator_default_set<Block *, 4> doneBlocks;
46 doneBlocks.insert(skipBlocks.begin(), skipBlocks.end());
48 for (
Block *block : llvm::depth_first_ext(headerBlock, doneBlocks)) {
49 if (skipHeader && block == headerBlock)
51 if (
failed(blockHandler(block)))
// Serializes a spirv::ConstantOp. The constant is materialized through
// prepareConstant() from the op's location, type and value attribute. When
// that returns a non-zero <id>, the op's SSA result is mapped to it in
// valueIDMap.
// NOTE(review): the embedded line numbers skip (61 -> 67), so the remainder
// of this function (e.g. its return statements) is not visible in this listing.
59 LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
60 if (
auto resultID = prepareConstant(op.getLoc(), op.getType(), op.value())) {
// Record the SSA result -> <id> mapping for later uses of this value.
61 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv::SpecConstantOp:
//  - the scalar default value is emitted via prepareConstantScalar();
//  - if the op carries an integer "spec_id" attribute, a SpecId decoration is
//    emitted on the new <id>;
//  - the <id> is recorded in specConstIDMap under the op's symbol name;
//  - a name is emitted for it via processName().
// NOTE(review): the embedded line numbers skip (e.g. 68 -> 71, 73 -> 77), so
// some call arguments and the failure returns are not visible in this listing.
67 LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
68 if (
auto resultID = prepareConstantScalar(op.getLoc(), op.default_value(),
71 if (
auto specID = op->getAttrOfType<IntegerAttr>(
"spec_id")) {
// emitDecoration() takes its decoration operands as uint32_t words.
72 auto val =
static_cast<uint32_t
>(specID.getInt());
73 if (
failed(emitDecoration(resultID, spirv::Decoration::SpecId, {val})))
// Map the symbol name to the <id>; presumably consulted later by
// getSpecConstID() when other ops reference this spec constant.
77 specConstIDMap[op.sym_name()] = resultID;
78 return processName(resultID, op.sym_name());
// Serializes a spirv::SpecConstantCompositeOp as an OpSpecConstantComposite
// instruction. Operands are the result type <id>, a fresh result <id>, and the
// <id> of each constituent spec constant, looked up by symbol name. The new
// <id> is recorded in specConstIDMap and a name is emitted for it.
// NOTE(review): the return-type line and several statements (e.g. where
// `typeID`, `operands` and `constituent` are declared, and the head of the
// encode call) are missing from this listing; embedded line numbers skip.
84 Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
86 if (
failed(processType(op.getLoc(), op.type(), typeID))) {
90 auto resultID = getNextID();
// Operand order: result type <id>, result <id>, then constituent <id>s.
93 operands.push_back(typeID);
94 operands.push_back(resultID);
96 auto constituents = op.constituents();
98 for (
auto index : llvm::seq<uint32_t>(0, constituents.size())) {
// Constituents are referenced by symbol and must already have an <id>;
// otherwise an error is emitted below.
101 auto constituentName = constituent.
getValue();
102 auto constituentID = getSpecConstID(constituentName);
104 if (!constituentID) {
105 return op.emitError(
"unknown result <id> for specialization constant ")
109 operands.push_back(constituentID);
113 spirv::Opcode::OpSpecConstantComposite, operands);
// Make the composite itself referenceable by symbol name.
114 specConstIDMap[op.sym_name()] = resultID;
116 return processName(resultID, op.sym_name());
// Serializes a spirv::SpecConstantOperationOp. Operands are pushed in this
// order:
//  - the result type <id>;
//  - a fresh result <id>;
//  - the opcode of the op enclosed in the op's region, resolved from its name
//    by spirv::symbolizeOpcode();
//  - the <id>s of operands (the loop producing `operand` is not visible here).
// The result <id> is recorded in valueIDMap.
// NOTE(review): the lines that select the enclosed op and write its name into
// `rss` are missing from this listing (embedded line numbers skip 136 -> 138,
// 146 -> 150) — confirm against the full source.
120 Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
122 if (
failed(processType(op.getLoc(), op.getType(), typeID))) {
126 auto resultID = getNextID();
129 operands.push_back(typeID);
130 operands.push_back(resultID);
// The enclosed op is taken from the first block of the op's region.
132 Block &block = op.getRegion().getBlocks().
front();
135 std::string enclosedOpName;
136 llvm::raw_string_ostream rss(enclosedOpName);
// Map the textual opcode name back to its spirv::Opcode enumerant; an error
// is reported when no match exists.
138 auto enclosedOpcode = spirv::symbolizeOpcode(rss.str());
140 if (!enclosedOpcode) {
141 op.emitError(
"Couldn't find op code for op ")
// The opcode is encoded as a literal operand word.
146 operands.push_back(static_cast<uint32_t>(*enclosedOpcode));
// NOTE(review): the non-zero <id> requirement is only checked by assert, so
// in release builds a missing <id> would be encoded as 0.
150 uint32_t
id = getValueID(operand);
151 assert(
id &&
"use before def!");
152 operands.push_back(
id);
157 valueIDMap[op.getResult()] = resultID;
// Serializes a spirv::UndefOp. The <id> is held in a per-type entry of
// undefValIDMap, so undef values of the same type appear to share one <id>.
// The op's result is mapped to that <id> in valueIDMap.
// NOTE(review): the lines that allocate the <id> and emit the instruction
// when the cache entry is still empty are missing from this listing (165-167,
// 169-172) — confirm the caching behaviour against the full source.
162 LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
163 auto undefType = op.getType();
// Reference into the per-type cache, so assignments update the map entry.
164 auto &
id = undefValIDMap[undefType];
168 if (
failed(processType(op.getLoc(), undefType, typeID)))
173 valueIDMap[op.getResult()] = id;
// Serializes a function op. The signature line is missing from this listing;
// the name is presumably processFuncOp — confirm. Steps visible here:
//  1. Process the function type, and the single result type (void when the
//     function has no results).
//  2. Build the function header operands, in order: result type <id>,
//     function <id>, function control, function type <id>. Then emit a name
//     for the function.
//  3. Give each argument a fresh <id> (recorded in valueIDMap) and encode it
//     together with its type <id>.
//  4. Reject external (body-less) functions. Otherwise use the entry block's
//     <id>, process the entry block first, then visit the remaining blocks in
//     depth-first order from the entry block.
//  5. Patch deferred OpPhi operand slots in functionBody with value <id>s.
//  6. Append header and body to `functions` and clear both scratch buffers.
// NOTE(review): many statements (encode calls, failure returns, closing
// braces) are missing from this listing; embedded line numbers skip.
178 LLVM_DEBUG(llvm::dbgs() <<
"-- start function '" << op.getName() <<
"' --\n");
// The per-function scratch buffers must be empty on entry.
179 assert(functionHeader.empty() && functionBody.empty());
181 uint32_t fnTypeID = 0;
183 if (
failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
// Only zero or one result type can be serialized; zero maps to void.
188 uint32_t resTypeID = 0;
189 auto resultTypes = op.getFunctionType().getResults();
190 if (resultTypes.size() > 1) {
191 return op.emitError(
"cannot serialize function with multiple return types");
193 if (
failed(processType(op.getLoc(),
194 (resultTypes.empty() ? getVoidType() : resultTypes[0]),
// Header operand order: result type, function <id>, control, function type.
198 operands.push_back(resTypeID);
199 auto funcID = getOrCreateFunctionID(op.getName());
200 operands.push_back(funcID);
201 operands.push_back(static_cast<uint32_t>(op.function_control()));
202 operands.push_back(fnTypeID);
206 if (
failed(processName(funcID, op.getName()))) {
// Each argument gets a fresh <id>, so uses in the body can resolve it via
// valueIDMap.
211 for (
auto arg : op.getArguments()) {
212 uint32_t argTypeID = 0;
213 if (
failed(processType(op.getLoc(), arg.getType(), argTypeID))) {
216 auto argValueID = getNextID();
217 valueIDMap[arg] = argValueID;
219 {argTypeID, argValueID});
// Functions without a body cannot be serialized here.
223 if (op.isExternal()) {
224 return op.emitError(
"external function is unhandled");
232 {getOrCreateBlockID(&op.front())});
// The entry block is processed first. The meaning of the `true` argument is
// not visible here; presumably it suppresses re-emitting the entry block's
// label, which uses the <id> just above — TODO confirm. The remaining blocks
// are then visited depth-first starting from the entry block.
233 if (
failed(processBlock(&op.front(),
true)))
236 &op.front(), [&](
Block *block) {
return processBlock(block); },
// OpPhi operands whose value <id>s were not known when the OpPhi was emitted
// were left at recorded offsets in functionBody. Patch them now.
242 for (
const auto &deferredValue : deferredPhiValues) {
244 uint32_t
id = getValueID(value);
245 LLVM_DEBUG(llvm::dbgs() <<
"[phi] fix reference of value " << value
246 <<
" to id = " <<
id <<
'\n');
247 assert(
id &&
"OpPhi references undefined value!");
248 for (
size_t offset : deferredValue.second)
249 functionBody[offset] = id;
251 deferredPhiValues.clear();
253 LLVM_DEBUG(llvm::dbgs() <<
"-- completed function '" << op.getName()
// Flush this function into the module-level `functions` buffer, then reset
// the scratch buffers for the next function.
258 functions.append(functionHeader.begin(), functionHeader.end());
259 functions.append(functionBody.begin(), functionBody.end());
260 functionHeader.clear();
261 functionBody.clear();
// Serializes a spirv::VariableOp. Its debug line goes into functionHeader, so
// it is presumably a function-scope variable. Operands, in order:
//  - the result type <id>;
//  - a fresh result <id>, also recorded in valueIDMap;
//  - the storage class from the storage-class attribute;
//  - the <id>s of the values in ODS operand group 0 (presumably the optional
//    initializer — confirm).
// Every attribute not listed in elidedAttrs is then emitted as a decoration on
// the result <id>.
// NOTE(review): some lines are missing from this listing. Examples: the guard
// presumably checking `attr`, the test before the "use before def" error, and
// the encode call.
266 LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
269 uint32_t resultID = 0;
270 uint32_t resultTypeID = 0;
271 if (
failed(processType(op.getLoc(), op.getType(), resultTypeID))) {
// Operand order: result type <id>, result <id>, storage class, initializer.
274 operands.push_back(resultTypeID);
275 resultID = getNextID();
276 valueIDMap[op.getResult()] = resultID;
277 operands.push_back(resultID);
// The storage class is encoded as an operand word, so the attribute is also
// elided from the decoration loop below.
278 auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
280 operands.push_back(static_cast<uint32_t>(
281 attr.cast<IntegerAttr>().getValue().getZExtValue()));
283 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
// Operand group 0 values must already have <id>s.
284 for (
auto arg : op.getODSOperands(0)) {
285 auto argID = getValueID(arg);
287 return emitError(op.getLoc(),
"operand 0 has a use before def");
289 operands.push_back(argID);
291 if (
failed(emitDebugLine(functionHeader, op.getLoc())))
// Remaining (non-elided) attributes become decorations on the result <id>.
294 for (
auto attr : op->getAttrs()) {
295 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
296 return attr.getName() == elided;
300 if (
failed(processDecoration(op.getLoc(), resultID, attr))) {
// Serializes a spirv::GlobalVariableOp. Its debug line goes into
// typesGlobalValues; the encode call itself is not visible here. Operands, in
// order:
//  - the result type <id>;
//  - a fresh result <id>, recorded in globalVarIDMap under the symbol name
//    with a name emitted for it;
//  - the storage class;
//  - optionally, the <id> of an initializer variable, which must already be
//    serialized.
// Attributes not listed in elidedAttrs are emitted as decorations.
// NOTE(review): the return-type line and several statements are missing from
// this listing; embedded line numbers skip.
308 Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
310 uint32_t resultTypeID = 0;
312 if (
failed(processType(varOp.getLoc(), varOp.type(), resultTypeID))) {
// The type is encoded as the first operand, so "type" is not also emitted as
// a decoration.
316 elidedAttrs.push_back(
"type");
318 operands.push_back(resultTypeID);
319 auto resultID = getNextID();
322 auto varName = varOp.sym_name();
324 if (
failed(processName(resultID, varName))) {
// Map the symbol name to the <id>; presumably consulted later by
// getVariableID() (e.g. from spv.mlir.addressof and spv.EntryPoint).
327 globalVarIDMap[varName] = resultID;
328 operands.push_back(resultID);
331 operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
// An initializer is referenced by symbol and must already have an <id>.
334 if (
auto initializer = varOp.initializer()) {
335 auto initializerID = getVariableID(*initializer);
336 if (!initializerID) {
338 "invalid usage of undefined variable as initializer");
340 operands.push_back(initializerID);
341 elidedAttrs.push_back(
"initializer");
344 if (
failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
// NOTE(review): "initializer" is pushed onto elidedAttrs a second time here
// (it was already pushed inside the initializer branch above). This is
// harmless but redundant; this unconditional push alone would suffice.
347 elidedAttrs.push_back(
"initializer");
// Remaining (non-elided) attributes become decorations on the result <id>.
350 for (
auto attr : varOp->getAttrs()) {
351 if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
352 return attr.getName() == elided;
356 if (
failed(processDecoration(varOp.getLoc(), resultID, attr))) {
// Serializes a spirv::SelectionOp region:
//  - every block in the body gets an <id> up front;
//  - the header block is processed with a callback that emits OpSelectionMerge
//    (merge block <id>, selection control);
//  - the remaining blocks are visited depth-first from the header, skipping
//    the header itself and the merge block.
// NOTE(review): several lines (including how the merge block itself is
// emitted, and the failure returns) are missing from this listing.
363 LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
366 auto &body = selectionOp.body();
// Pre-assign block <id>s, presumably so that branches can reference blocks
// that have not been emitted yet.
367 for (
Block &block : body)
368 getOrCreateBlockID(&block);
370 auto *headerBlock = selectionOp.getHeaderBlock();
371 auto *mergeBlock = selectionOp.getMergeBlock();
372 auto headerID = getBlockID(headerBlock);
373 auto mergeID = getBlockID(mergeBlock);
374 auto loc = selectionOp.getLoc();
// Emits the OpSelectionMerge instruction (with a debug line) into functionBody.
386 auto emitSelectionMerge = [&]() {
387 if (
failed(emitDebugLine(functionBody, loc)))
// Flag that a merge instruction was just emitted; the consumer of this flag
// is not visible here.
389 lastProcessedWasMergeInst =
true;
391 functionBody, spirv::Opcode::OpSelectionMerge,
392 {mergeID,
static_cast<uint32_t
>(selectionOp.selection_control())});
// Process the header block, passing emitSelectionMerge as a callback.
396 processBlock(headerBlock,
false, emitSelectionMerge)))
// Visit the remaining blocks depth-first from the header; skipHeader is true,
// and the merge block is excluded from the traversal.
403 headerBlock, [&](
Block *block) {
return processBlock(block); },
404 true, {mergeBlock})))
412 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
413 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
414 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Serializes a spirv::LoopOp region:
//  - every block except the region's first block gets an <id> up front;
//  - the header block is processed with a callback that emits OpLoopMerge
//    (merge <id>, continue <id>, loop control);
//  - the remaining blocks are visited depth-first from the header, skipping
//    the header, continue and merge blocks;
//  - the continue block is then processed explicitly.
// NOTE(review): why the first block is excluded, and how the merge block is
// emitted, is not visible here because lines are missing from this listing.
418 LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
422 auto &body = loopOp.body();
// Pre-assign <id>s to all blocks after the first one.
423 for (
Block &block : llvm::make_range(std::next(body.begin(), 1), body.end()))
424 getOrCreateBlockID(&block);
426 auto *headerBlock = loopOp.getHeaderBlock();
427 auto *continueBlock = loopOp.getContinueBlock();
428 auto *mergeBlock = loopOp.getMergeBlock();
429 auto headerID = getBlockID(headerBlock);
430 auto continueID = getBlockID(continueBlock);
431 auto mergeID = getBlockID(mergeBlock);
432 auto loc = loopOp.getLoc();
// Emits the OpLoopMerge instruction (with a debug line) into functionBody.
448 auto emitLoopMerge = [&]() {
449 if (
failed(emitDebugLine(functionBody, loc)))
// Flag that a merge instruction was just emitted; the consumer of this flag
// is not visible here.
451 lastProcessedWasMergeInst =
true;
453 functionBody, spirv::Opcode::OpLoopMerge,
454 {mergeID, continueID,
static_cast<uint32_t
>(loopOp.loop_control())});
457 if (
failed(processBlock(headerBlock,
false, emitLoopMerge)))
// Visit the loop body depth-first from the header. The header is skipped, and
// the continue and merge blocks are excluded from the traversal.
464 headerBlock, [&](
Block *block) {
return processBlock(block); },
465 true, {continueBlock, mergeBlock})))
// The continue block, excluded above, is emitted explicitly after the body.
469 if (
failed(processBlock(continueBlock)))
477 LLVM_DEBUG(llvm::dbgs() <<
"done merge ");
478 LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs()));
479 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Parameter of the conditional-branch serializer; the first signature line,
// presumably `LogicalResult Serializer::processBranchConditionalOp(`, is
// missing from this listing. The function:
//  - computes the condition <id>;
//  - computes the true/false target block <id>s, created on demand;
//  - appends any branch weights to `arguments` as literal words;
//  - emits a debug line into functionBody.
// NOTE(review): the initialization of `arguments` and the encode call are not
// visible; `arguments` presumably starts as {conditionID, trueLabelID,
// falseLabelID} — confirm.
484 spirv::BranchConditionalOp condBranchOp) {
485 auto conditionID = getValueID(condBranchOp.condition());
// Target blocks may not be emitted yet, so their <id>s are created on demand.
486 auto trueLabelID = getOrCreateBlockID(condBranchOp.getTrueBlock());
487 auto falseLabelID = getOrCreateBlockID(condBranchOp.getFalseBlock());
// Optional branch weights are appended from the IntegerAttr array.
490 if (
auto weights = condBranchOp.branch_weights()) {
491 for (
auto val : weights->getValue())
492 arguments.push_back(val.cast<IntegerAttr>().getInt());
495 if (
failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
// Serializes a spirv::BranchOp. It emits a debug line into functionBody, then
// an instruction whose single operand is the target block's <id>, created on
// demand. That instruction is presumably OpBranch; the encode call's first
// line is missing from this listing.
502 LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
503 if (
failed(emitDebugLine(functionBody, branchOp.getLoc())))
506 {getOrCreateBlockID(branchOp.getTarget())});
// Serializes a spirv::AddressOfOp. No instruction is emitted in the visible
// code. The op's pointer result is mapped to the <id> of the referenced global
// variable. An error is reported when that variable has no <id> yet; the
// guarding condition is not visible in this listing.
510 LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
511 auto varName = addressOfOp.variable();
512 auto variableID = getVariableID(varName);
514 return addressOfOp.emitError(
"unknown result <id> for variable ")
// Uses of the pointer resolve directly to the global variable's <id>.
517 valueIDMap[addressOfOp.pointer()] = variableID;
// Serializes a spirv::ReferenceOfOp. No instruction is emitted in the visible
// code. The op's result is mapped to the <id> of the referenced specialization
// constant. An error is reported when that constant has no <id> yet; the
// guarding condition is not visible in this listing.
522 Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
523 auto constName = referenceOfOp.spec_const();
524 auto constID = getSpecConstID(constName);
526 return referenceOfOp.emitError(
527 "unknown result <id> for specialization constant ")
// Uses of the reference resolve directly to the spec constant's <id>.
530 valueIDMap[referenceOfOp.reference()] = constID;
// Serializes a spirv::EntryPointOp. Visible operands, in order:
//  - the execution model;
//  - the <id> of the referenced function, which must already be serialized;
//  - the <id> of each interface global variable, each of which must also
//    already be serialized.
// NOTE(review): several lines are missing from this listing (embedded line
// numbers skip, e.g. 548 -> 553 and 554 -> 557).
536 Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
539 operands.push_back(static_cast<uint32_t>(op.execution_model()));
// getFunctionID() (not getOrCreate...) — the function must already exist.
541 auto funcID = getFunctionID(op.fn());
543 return op.emitError(
"missing <id> for function ")
545 <<
"; function needs to be defined before spv.EntryPoint is " 548 operands.push_back(funcID);
// Interface variables are referenced by symbol and must already have <id>s.
553 if (
auto interface = op.interface()) {
554 for (
auto var : interface.getValue()) {
557 return op.emitError(
"referencing undefined global variable." 558 "spv.EntryPoint is at the end of spv.module. All " 559 "referenced variables should already be defined");
561 operands.push_back(
id);
// Serializes a spirv::ControlBarrierOp. For each attribute named in argNames,
// the integer attribute is materialized as a constant via
// prepareConstantInt(), and that constant's <id> is pushed as an operand.
// NOTE(review): the last entry of argNames and the lines between 577 and 581
// are missing from this listing; the third name is presumably
// "memory_semantics" — confirm.
570 Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
571 StringRef argNames[] = {
"execution_scope",
"memory_scope",
// Operands are constant <id>s, not literal words.
575 for (
auto argName : argNames) {
576 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
577 auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
581 operands.push_back(operand);
// Serializes a spirv::ExecutionModeOp. Operands, in order:
//  - the <id> of the target function, which must already be serialized;
//  - the execution mode;
//  - each integer in the op's `values` array attribute, as a literal word.
// NOTE(review): some lines are missing from this listing; embedded line
// numbers skip.
591 Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
// getFunctionID() (not getOrCreate...) — the function must already exist.
594 auto funcID = getFunctionID(op.fn());
596 return op.emitError(
"missing <id> for function ")
598 <<
"; function needs to be serialized before ExecutionModeOp is " 601 operands.push_back(funcID);
603 operands.push_back(static_cast<uint32_t>(op.execution_mode()));
606 auto values = op.values();
// Each value's APInt is zero-extended, then truncated to a 32-bit word.
608 for (
auto &intVal : values.getValue()) {
609 operands.push_back(static_cast<uint32_t>(
610 intVal.cast<IntegerAttr>().getValue().getZExtValue()));
// Serializes a spirv::MemoryBarrierOp. For each of the "memory_scope" and
// "memory_semantics" integer attributes:
//  - the attribute is materialized as a constant via prepareConstantInt();
//  - that constant's <id> is pushed as an operand.
// NOTE(review): lines 627-629 and the encode call are missing from this
// listing.
620 Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
621 StringRef argNames[] = {
"memory_scope",
"memory_semantics"};
// Operands are constant <id>s, not literal words.
624 for (
auto argName : argNames) {
625 auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
626 auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
630 operands.push_back(operand);
// Serializes a spirv::FunctionCallOp. It:
//  - processes the result type (void when the call has no results);
//  - gets the callee's function <id>, created on demand;
//  - allocates a fresh <id> for the call;
//  - pushes the <id>s of the call arguments into `operands`.
// The lines that seed `operands` (649-650) are not visible. The call's <id> is
// recorded in valueIDMap unless the result type is NoneType.
// NOTE(review): the no-result path relies on getVoidType() returning a
// NoneType; otherwise op.getResult(0) would be accessed on a result-less op.
// Confirm.
639 Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
640 auto funcName = op.callee();
641 uint32_t resTypeID = 0;
643 Type resultTy = op.getNumResults() ? *op.result_type_begin() : getVoidType();
644 if (
failed(processType(op.getLoc(), resultTy, resTypeID)))
// getOrCreateFunctionID(), presumably so calls may precede the callee's own
// serialization.
647 auto funcID = getOrCreateFunctionID(funcName);
648 auto funcCallID = getNextID();
// Arguments must already have <id>s; this is only checked by assert, so
// release builds would encode 0 for a missing value.
651 for (
auto value : op.arguments()) {
652 auto valueID = getValueID(
value);
653 assert(valueID &&
"cannot find a value for spv.FunctionCall");
654 operands.push_back(valueID);
657 if (!resultTy.
isa<NoneType>())
658 valueIDMap[op.getResult(0)] = funcCallID;
// Serializes a spirv::CopyMemoryOp. Operands, in order:
//  - the <id>s of all op operands;
//  - optional literal words read from the integer attributes "memory_access",
//    "alignment", "source_memory_access" and "source_alignment", each in that
//    order and only when present.
// Each of those attribute names is added to elidedAttrs. A debug line is
// emitted into functionBody.
// NOTE(review): lines are missing from this listing (e.g. the encode call and
// the use of elidedAttrs).
666 Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
// NOTE(review): the non-zero <id> check is assert-only; release builds would
// encode a missing <id> as 0.
670 for (
Value operand : op->getOperands()) {
671 auto id = getValueID(operand);
672 assert(
id &&
"use before def!");
673 operands.push_back(
id);
// Optional memory-operand literals: zero-extended, then truncated to 32 bits.
676 if (
auto attr = op->getAttr(
"memory_access")) {
677 operands.push_back(static_cast<uint32_t>(
678 attr.cast<IntegerAttr>().getValue().getZExtValue()));
681 elidedAttrs.push_back(
"memory_access");
683 if (
auto attr = op->getAttr(
"alignment")) {
684 operands.push_back(static_cast<uint32_t>(
685 attr.cast<IntegerAttr>().getValue().getZExtValue()));
688 elidedAttrs.push_back(
"alignment");
// Source-side memory operands follow the target-side ones.
690 if (
auto attr = op->getAttr(
"source_memory_access")) {
691 operands.push_back(static_cast<uint32_t>(
692 attr.cast<IntegerAttr>().getValue().getZExtValue()));
695 elidedAttrs.push_back(
"source_memory_access");
697 if (
auto attr = op->getAttr(
"source_alignment")) {
698 operands.push_back(static_cast<uint32_t>(
699 attr.cast<IntegerAttr>().getValue().getZExtValue()));
702 elidedAttrs.push_back(
"source_alignment");
703 if (
failed(emitDebugLine(functionBody, op.getLoc())))
712 #define GET_SERIALIZATION_FNS 713 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" TODO: Remove this file when SCCP and integer range analysis have been ported to the new framework...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
StringRef stripDialect() const
Return the operation name with dialect name stripped, if it has one.
Operation is a basic unit of execution within MLIR.
operand_range getOperands()
Returns an iterator on the underlying Value's.
Block represents an ordered list of Operations.
A symbol reference with a reference path containing a single element.
static LogicalResult visitInPrettyBlockOrder(Block *headerBlock, function_ref< LogicalResult(Block *)> blockHandler, bool skipHeader=false, BlockRange skipBlocks={})
A pre-order depth-first visitor function for processing basic blocks.
OpListType & getOperations()
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
void encodeInstructionInto(SmallVectorImpl< uint32_t > &binary, spirv::Opcode op, ArrayRef< uint32_t > operands)
Encodes a SPIR-V instruction with the given opcode and operands into the given binary vector.
static constexpr const bool value
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
This class represents an efficient way to signal success or failure.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
StringRef getValue() const
Returns the name of the held symbol reference.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
This class provides an abstraction over the different types of ranges over Blocks.
void encodeStringLiteralInto(SmallVectorImpl< uint32_t > &binary, StringRef literal)
Encodes a SPIR-V literal string into the given binary vector.
OperationName getName()
The name of an operation is the key identifier for it.