MLIR 23.0.0git
TosaInputShape.cpp
Go to the documentation of this file.
1//===- TosaInputShape.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Pass that overrides the dynamic input shapes of function arguments to
10// specified static shapes. If a specified static shape conflicts with the
11// static dimensions in an original input shape, an error is reported.
12//
13//===----------------------------------------------------------------------===//
14
#include "mlir/Pass/Pass.h"

#include <string>
#include <utility>
#include <vector>
19
20namespace mlir {
21namespace tosa {
22#define GEN_PASS_DEF_TOSAINPUTSHAPE
23#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
24} // namespace tosa
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::tosa;
29
30namespace {
31
32typedef std::pair<size_t, SmallVector<int64_t>> IdxAndShape;
33
34FailureOr<IdxAndShape> parseInputShape(Location loc, StringRef input) {
35 if (!input.consume_front("arg")) {
36 emitError(loc) << "expected prefix 'arg' at the start of " << input;
37 return failure();
38 }
39
40 const size_t colonPos = input.find(':');
41 if (colonPos == StringRef::npos) {
42 emitError(loc) << "expected ':' after argument index in '" << input << "'";
43 return failure();
44 }
45
46 const StringRef indexStr = input.substr(0, colonPos);
47 input = input.substr(colonPos + 1);
48
49 size_t index;
50 if (indexStr.getAsInteger(10, index) || index < 0) {
51 emitError(loc) << "invalid argument index, got " << indexStr;
52 return failure();
53 }
54
56 while (!input.empty()) {
57 const size_t xPos = input.find("x");
58 StringRef dimStr;
59 if (xPos == StringRef::npos) {
60 dimStr = input;
61 input = "";
62 } else {
63 dimStr = input.substr(0, xPos);
64 input = input.substr(xPos + 1);
65 }
66
67 int64_t dimVal;
68 if (dimStr.getAsInteger(10, dimVal) || dimVal <= 0) {
69 return failure();
70 }
71 shape.push_back(dimVal);
72 }
73
74 const auto idxAndShape = std::make_pair(index, shape);
75 return {idxAndShape};
76}
77
78// Parse input shape arguments from command line input. Returns parsed
79// static shapes and an optional error message.
80// For example:
81// "args=arg0:5x10,arg8:3x9" => {{{0, {5, 10}}, {8, {3, 9}}}, ""}
82// "args=arg0:" => {{}, "error message"}
83FailureOr<SmallVector<IdxAndShape>>
84parseInputShapes(Location loc, const std::vector<std::string> &args) {
85 SmallVector<IdxAndShape> inputShapes;
86 for (const std::string &arg : args) {
87 const auto maybeInputShape = parseInputShape(loc, arg);
88 if (failed(maybeInputShape))
89 return failure();
90 inputShapes.push_back(maybeInputShape.value());
91 }
92 return inputShapes;
93}
94
95struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
96public:
97 TosaInputShape() = default;
98
99 explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
100 this->args = args;
101 }
102
103 void runOnOperation() override {
104 MLIRContext *context = &getContext();
105 const Location unknownLoc = UnknownLoc::get(context);
106 const auto maybeArgsParsed = parseInputShapes(unknownLoc, args);
107 if (failed(maybeArgsParsed))
108 return;
109 const SmallVector<IdxAndShape> argsParsed = maybeArgsParsed.value();
110 func::FuncOp func = getOperation();
111
112 const auto getUpdatedTensorType =
113 [&](size_t argIdx, ArrayRef<Type> argTypes,
114 ArrayRef<int64_t> requestedShape) -> FailureOr<Type> {
115 const size_t numInputs = argTypes.size();
116 if (argIdx >= numInputs)
117 return func.emitError()
118 << "provided arg index " << argIdx
119 << " is larger than number of inputs " << numInputs << ".";
120
121 auto tensorType = dyn_cast<TensorType>(argTypes[argIdx]);
122 if (!tensorType)
123 return func.emitError()
124 << "expected tensor type, got " << argTypes[argIdx];
125
126 const ArrayRef<int64_t> originalShape = tensorType.getShape();
127 if (failed(verifyCompatibleShape(originalShape, requestedShape)))
128 return func.emitError()
129 << "arg" << argIdx
130 << " has incompatible shape with requested input shape ("
131 << requestedShape << "), got " << tensorType;
132 return tensorType.cloneWith(requestedShape, tensorType.getElementType());
133 };
134
135 // Update argument shapes in the entry block if the function has body.
136 if (!func.getBody().empty()) {
137 Block &entryBlock = func.getBody().front();
138 const SmallVector<Type> argTypes(entryBlock.getArgumentTypes());
139 for (const auto &[argIdx, shape] : argsParsed) {
140 FailureOr<Type> newTensorType =
141 getUpdatedTensorType(argIdx, argTypes, shape);
142 if (failed(newTensorType))
143 return signalPassFailure();
144
145 entryBlock.getArgument(argIdx).setType(newTensorType.value());
146 }
147 }
148
149 // Get new func argument types
150 const FunctionType oldFunctionType = func.getFunctionType();
151 const ArrayRef<Type> oldInputTypes = oldFunctionType.getInputs();
152 SmallVector<Type> newInputs(oldInputTypes.begin(), oldInputTypes.end());
153 for (const auto &[argIdx, shape] : argsParsed) {
154 FailureOr<Type> newTensorType =
155 getUpdatedTensorType(argIdx, oldInputTypes, shape);
156 if (failed(newTensorType))
157 return signalPassFailure();
158
159 newInputs[argIdx] = newTensorType.value();
160 }
161
162 // Update function signature
163 const auto oldResultTypes = oldFunctionType.getResults();
164 SmallVector<Type> newResults(oldResultTypes.begin(), oldResultTypes.end());
165 if (!func.getBody().empty()) {
166 Block &lastBlock = func.getBody().back();
167 const Operation *terminator = lastBlock.getTerminator();
168 if (auto returnOp = dyn_cast_or_null<func::ReturnOp>(terminator)) {
169 const auto returnTypes = returnOp.getOperandTypes();
170 newResults.assign(returnTypes.begin(), returnTypes.end());
171 }
172 }
173
174 const FunctionType newFunctionType =
175 oldFunctionType.clone(newInputs, newResults);
176 func.setFunctionType(newFunctionType);
177 }
178};
179
180} // namespace
181
182std::unique_ptr<Pass>
183mlir::tosa::createTosaInputShapePass(std::vector<std::string> args) {
184 return std::make_unique<TosaInputShape>(args);
185}
b getContext())
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
std::unique_ptr< Pass > createTosaInputShapePass(std::vector< std::string > args={})
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.