MLIR 23.0.0git
TosaInputShape.cpp
Go to the documentation of this file.
1//===- TosaInputShape.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Pass that overrides the dynamic input shapes of function arguments to
10// specified static shapes. If a specified static shape conflicts with the
11// static dimensions in an original input shape, an error is reported.
12//
13//===----------------------------------------------------------------------===//
14
18#include "mlir/Pass/Pass.h"
19
20namespace mlir {
21namespace tosa {
22#define GEN_PASS_DEF_TOSAINPUTSHAPE
23#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
24} // namespace tosa
25} // namespace mlir
26
27using namespace mlir;
28using namespace mlir::tosa;
29
30namespace {
31
32typedef std::pair<size_t, SmallVector<int64_t>> IdxAndShape;
33
34FailureOr<IdxAndShape> parseInputShape(Location loc, StringRef input) {
35 if (!input.consume_front("arg")) {
36 emitError(loc) << "expected prefix 'arg' at the start of " << input;
37 return failure();
38 }
39
40 const size_t colonPos = input.find(':');
41 if (colonPos == StringRef::npos) {
42 emitError(loc) << "expected ':' after argument index in '" << input << "'";
43 return failure();
44 }
45
46 const StringRef indexStr = input.substr(0, colonPos);
47 input = input.substr(colonPos + 1);
48
49 size_t index;
50 if (indexStr.getAsInteger(10, index) || index < 0) {
51 emitError(loc) << "invalid argument index, got " << indexStr;
52 return failure();
53 }
54
56 while (!input.empty()) {
57 const size_t xPos = input.find("x");
58 StringRef dimStr;
59 if (xPos == StringRef::npos) {
60 dimStr = input;
61 input = "";
62 } else {
63 dimStr = input.substr(0, xPos);
64 input = input.substr(xPos + 1);
65 }
66
67 int64_t dimVal;
68 if (dimStr.getAsInteger(10, dimVal) || dimVal <= 0) {
69 return failure();
70 }
71 shape.push_back(dimVal);
72 }
73
74 const auto idxAndShape = std::make_pair(index, shape);
75 return {idxAndShape};
76}
77
78// Parse input shape arguments from command line input. Returns parsed
79// static shapes and an optional error message.
80// For example:
81// "args=arg0:5x10,arg8:3x9" => {{{0, {5, 10}}, {8, {3, 9}}}, ""}
82// "args=arg0:" => {{}, "error message"}
83FailureOr<SmallVector<IdxAndShape>>
84parseInputShapes(Location loc, const std::vector<std::string> &args) {
85 SmallVector<IdxAndShape> inputShapes;
86 for (const std::string &arg : args) {
87 const auto maybeInputShape = parseInputShape(loc, arg);
88 if (failed(maybeInputShape))
89 return failure();
90 inputShapes.push_back(maybeInputShape.value());
91 }
92 return inputShapes;
93}
94
95struct TosaInputShape : public tosa::impl::TosaInputShapeBase<TosaInputShape> {
96public:
97 TosaInputShape() = default;
98
99 explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
100 this->args = args;
101 }
102
103 void runOnOperation() override {
104 MLIRContext *context = &getContext();
105 const Location unknownLoc = UnknownLoc::get(context);
106 const auto maybeArgsParsed = parseInputShapes(unknownLoc, args);
107 if (failed(maybeArgsParsed))
108 return;
109 const SmallVector<IdxAndShape> argsParsed = maybeArgsParsed.value();
110 func::FuncOp func = getOperation();
111
112 const auto getUpdatedTensorType =
113 [&](size_t argIdx, ArrayRef<Type> argTypes,
114 ArrayRef<int64_t> requestedShape) -> FailureOr<Type> {
115 const size_t numInputs = argTypes.size();
116 if (argIdx >= numInputs)
117 return func.emitError()
118 << "provided arg index " << argIdx
119 << " is larger than number of inputs " << numInputs << ".";
120
121 auto tensorType = dyn_cast<TensorType>(argTypes[argIdx]);
122 if (!tensorType)
123 return func.emitError()
124 << "expected tensor type, got " << argTypes[argIdx];
125
126 const ArrayRef<int64_t> originalShape = tensorType.getShape();
127 if (failed(verifyCompatibleShape(originalShape, requestedShape)))
128 return func.emitError()
129 << "arg" << argIdx
130 << " has incompatible shape with requested input shape ("
131 << requestedShape << "), got " << tensorType;
132 return tensorType.cloneWith(requestedShape, tensorType.getElementType());
133 };
134
135 // Update argument shapes in the entry block
136 Block &entryBlock = func.getBody().front();
137 const SmallVector<Type> argTypes(entryBlock.getArgumentTypes());
138 for (const auto &[argIdx, shape] : argsParsed) {
139 FailureOr<Type> newTensorType =
140 getUpdatedTensorType(argIdx, argTypes, shape);
141 if (failed(newTensorType))
142 return signalPassFailure();
143
144 entryBlock.getArgument(argIdx).setType(newTensorType.value());
145 }
146
147 // Get new func argument types
148 const FunctionType oldFunctionType = func.getFunctionType();
149 const ArrayRef<Type> oldInputTypes = oldFunctionType.getInputs();
150 SmallVector<Type> newInputs(oldInputTypes.begin(), oldInputTypes.end());
151 for (const auto &[argIdx, shape] : argsParsed) {
152 FailureOr<Type> newTensorType =
153 getUpdatedTensorType(argIdx, oldInputTypes, shape);
154 if (failed(newTensorType))
155 return signalPassFailure();
156
157 newInputs[argIdx] = newTensorType.value();
158 }
159
160 // Update function signature
161 Block &lastBlock = func.getBody().back();
162 const Operation *terminator = lastBlock.getTerminator();
163 SmallVector<Type> newResults;
164 if (auto returnOp = dyn_cast_or_null<func::ReturnOp>(terminator)) {
165 const auto types = returnOp.getOperandTypes();
166 newResults.assign(types.begin(), types.end());
167 } else {
168 const auto types = oldFunctionType.getResults();
169 newResults.assign(types.begin(), types.end());
170 }
171 const FunctionType newFunctionType =
172 oldFunctionType.clone(newInputs, newResults);
173 func.setFunctionType(newFunctionType);
174 }
175};
176
177} // namespace
178
179std::unique_ptr<Pass>
180mlir::tosa::createTosaInputShapePass(std::vector<std::string> args) {
181 return std::make_unique<TosaInputShape>(args);
182}
b getContext())
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Definition Value.h:116
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
std::unique_ptr< Pass > createTosaInputShapePass(std::vector< std::string > args={})
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.