MLIR 23.0.0git
TosaToSPIRVTosaConstants.cpp
Go to the documentation of this file.
1//===- TosaToSPIRVTosaConstants.cpp - TOSA graph constants ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements preprocessing that marks TOSA constants that should be
10// lowered to SPIR-V Graph constants.
11//
12//===----------------------------------------------------------------------===//
13
18#include <optional>
19
20namespace mlir {
21#define GEN_PASS_DEF_TOSATOSPIRVTOSAMARKGRAPHCONSTANTS
22#include "mlir/Conversion/Passes.h.inc"
23
24namespace tosa {
25namespace {
26
/// Element-count threshold for `tosa.const`: constants with more elements than
/// this are assigned a graph constant ID, smaller ones are left unmarked.
constexpr uint32_t maxInlineConstElements{16};
/// Element-count threshold for `tosa.const_shape`, applied the same way.
constexpr uint32_t maxInlineConstShapeElements{32};
29
30std::optional<ElementsAttr> getConstantValues(Operation *op) {
31 if (auto constOp = dyn_cast<tosa::ConstOp>(op))
32 return constOp.getValuesAttr();
33 if (auto constShapeOp = dyn_cast<tosa::ConstShapeOp>(op))
34 return constShapeOp.getValuesAttr();
35 return std::nullopt;
36}
37
38bool shouldMarkGraphConstant(Operation *op) {
39 if (op->use_empty())
40 return false;
41
42 std::optional<ElementsAttr> values = getConstantValues(op);
43 if (!values)
44 return false;
45
46 uint32_t maxInlineElements = isa<tosa::ConstOp>(op)
47 ? maxInlineConstElements
48 : maxInlineConstShapeElements;
49 return values->size() > maxInlineElements;
50}
51
52void setGraphConstantId(Operation *op, uint32_t id) {
53 auto i32Type = IntegerType::get(op->getContext(), 32);
54 op->setAttr(graphARMGraphConstantIdAttrName, IntegerAttr::get(i32Type, id));
55}
56
57struct TosaToSPIRVTosaMarkGraphConstants final
58 : impl::TosaToSPIRVTosaMarkGraphConstantsBase<
59 TosaToSPIRVTosaMarkGraphConstants> {
60 void runOnOperation() override {
61 uint32_t nextConstantId = 0;
62 WalkResult result =
63 getOperation().walk([&](Operation *op) {
64 if (!isa<tosa::ConstOp, tosa::ConstShapeOp>(op))
65 return WalkResult::advance();
66
67 if (op->hasAttr(graphARMGraphConstantIdAttrName)) {
68 op->emitOpError()
69 << "already has `" << graphARMGraphConstantIdAttrName
70 << "`; this pass assigns graph constant IDs automatically and "
71 "does not support pre-marked constants";
72 return WalkResult::interrupt();
73 }
74
75 if (shouldMarkGraphConstant(op))
76 setGraphConstantId(op, nextConstantId++);
77 return WalkResult::advance();
78 });
79
80 if (result.wasInterrupted())
81 signalPassFailure();
82 }
83};
84
85} // namespace
86
88 return std::make_unique<TosaToSPIRVTosaMarkGraphConstants>();
89}
90
91} // namespace tosa
92} // namespace mlir
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
std::unique_ptr< Pass > createTosaToSPIRVTosaMarkGraphConstants()
constexpr llvm::StringLiteral graphARMGraphConstantIdAttrName
Include the generated interface declarations.