22#define GEN_PASS_DEF_TOSAINPUTSHAPE
23#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
32typedef std::pair<size_t, SmallVector<int64_t>> IdxAndShape;
// Parses a single input-shape specification of the form
// "arg<index>:<d0>x<d1>x...", e.g. "arg0:1x8x8x3".
//
//  - The string must start with the literal prefix "arg".
//  - <index> is the text between "arg" and the first ':'. It must parse as a
//    base-10 integer and be non-negative.
//  - The text after ':' is split on 'x'. Every dimension must parse as a
//    base-10 integer strictly greater than zero, so zero-sized or negative
//    (dynamic) sizes are rejected.
//
// Every malformed component emits a diagnostic at `loc`.
// NOTE(review): this excerpt omits several body lines: the failure returns,
// the declarations of `index`/`shape`/`dimStr`/`dimVal`, the else-branch
// braces and the final return. Presumably each error path returns failure()
// and the function returns the (index, shape) pair built at the end.
// Confirm against the full file.
34FailureOr<IdxAndShape> parseInputShape(
Location loc, StringRef input) {
// Strip the mandatory "arg" prefix. On mismatch consume_front leaves `input`
// unchanged, so the full offending text is echoed in the diagnostic.
35 if (!input.consume_front(
"arg")) {
36 emitError(loc) <<
"expected prefix 'arg' at the start of " << input;
// The argument index is terminated by the first ':'.
40 const size_t colonPos = input.find(
':');
41 if (colonPos == StringRef::npos) {
42 emitError(loc) <<
"expected ':' after argument index in '" << input <<
"'";
// Split into the index text and the shape text that follows ':'.
46 const StringRef indexStr = input.substr(0, colonPos);
47 input = input.substr(colonPos + 1);
// Reject non-numeric or negative argument indices.
50 if (indexStr.getAsInteger(10,
index) ||
index < 0) {
51 emitError(loc) <<
"invalid argument index, got " << indexStr;
// Consume the shape text one 'x'-separated dimension at a time. If nothing
// follows ':', the loop never runs and no dimensions are appended.
56 while (!input.empty()) {
// NOTE(review): the char overload find('x') would avoid a
// single-character string search.
57 const size_t xPos = input.find(
"x");
// No further separator. Presumably the remaining text is taken as the last
// dimension; that branch body is not visible in this excerpt.
59 if (xPos == StringRef::npos) {
// Otherwise split off the leading dimension and advance past the 'x'.
63 dimStr = input.substr(0, xPos);
64 input = input.substr(xPos + 1);
// Only strictly positive, base-10 sizes are accepted.
68 if (dimStr.getAsInteger(10, dimVal) || dimVal <= 0) {
71 shape.push_back(dimVal);
// Pair the parsed index with its shape. It converts to IdxAndShape
// (size_t index) on return; the index was already checked to be >= 0.
74 const auto idxAndShape = std::make_pair(
index,
shape);
// Parses every entry of `args` with parseInputShape. All entries report
// their diagnostics at the same `loc`. Successfully parsed entries are
// appended in input order.
// NOTE(review): the declaration of `inputShapes`, the statement executed on
// a failed entry and the final return are not visible in this excerpt.
// Presumably the first malformed entry aborts with failure() and otherwise
// the collected shapes are returned. Confirm.
83FailureOr<SmallVector<IdxAndShape>>
84parseInputShapes(
Location loc,
const std::vector<std::string> &args) {
86 for (
const std::string &arg : args) {
87 const auto maybeInputShape = parseInputShape(loc, arg);
88 if (
failed(maybeInputShape))
90 inputShapes.push_back(maybeInputShape.value());
// Defaulted default constructor.
97 TosaInputShape() =
default;
// Builds the pass from input-shape specifications of the form
// "arg<index>:<d0>x<d1>x..." (e.g. "arg0:1x8x8x3"). It delegates to the
// default constructor first.
// NOTE(review): the constructor body is not visible in this excerpt.
// Presumably it stores `args` into the member/option that runOnOperation
// reads. Confirm.
99 explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
// Retypes selected arguments of the current func::FuncOp according to the
// parsed (argIdx, shape) specifications in `args`.
//
// For each specification, getUpdatedTensorType requires that:
//   * argIdx is a valid input index;
//   * the argument has a TensorType;
//   * the requested shape is compatible with the argument's current shape
//     (presumably checked with verifyCompatibleShape; the condition line is
//     not visible in this excerpt).
// The new type is the original tensor type cloned with the requested shape
// and the same element type.
//
// The entry-block arguments are updated when the function has a body, and
// the function signature is always rebuilt. Result types are re-derived
// from a func.return found in the last block, if there is one.
// Validation failures call signalPassFailure(). The statement run on a parse
// failure is not visible here; presumably it does the same.
103 void runOnOperation()
override {
// The specs are plain strings with no source location, so parse
// diagnostics are attached to an unknown location.
105 const Location unknownLoc = UnknownLoc::get(context);
106 const auto maybeArgsParsed = parseInputShapes(unknownLoc, args);
107 if (
failed(maybeArgsParsed))
109 const SmallVector<IdxAndShape> argsParsed = maybeArgsParsed.value();
110 func::FuncOp func = getOperation();
// Validates one (argIdx, requestedShape) request against `argTypes`. It
// returns the tensor type with the requested shape, or emits an error on
// `func` and returns failure.
112 const auto getUpdatedTensorType =
113 [&](
size_t argIdx, ArrayRef<Type> argTypes,
114 ArrayRef<int64_t> requestedShape) -> FailureOr<Type> {
115 const size_t numInputs = argTypes.size();
// NOTE(review): the check is >=, so an index equal to the number of inputs
// is also rejected, despite the "larger than" wording of the message.
116 if (argIdx >= numInputs)
117 return func.emitError()
118 <<
"provided arg index " << argIdx
119 <<
" is larger than number of inputs " << numInputs <<
".";
// Only tensor-typed arguments can be reshaped; other types are an error.
121 auto tensorType = dyn_cast<TensorType>(argTypes[argIdx]);
123 return func.emitError()
124 <<
"expected tensor type, got " << argTypes[argIdx];
// The requested shape must be compatible with the current shape; the guard
// line itself is not visible in this excerpt.
126 const ArrayRef<int64_t> originalShape = tensorType.getShape();
128 return func.emitError()
130 <<
" has incompatible shape with requested input shape ("
131 << requestedShape <<
"), got " << tensorType;
// Keep the element type; only the shape changes.
132 return tensorType.cloneWith(requestedShape, tensorType.getElementType());
// Function definitions: retype the entry-block arguments so uses inside the
// body see the new shapes. Presumably this is done via
// entryBlock.getArgument(argIdx).setType(...), and `argTypes` comes from the
// entry block's argument types; neither line is visible in this excerpt.
// Functions without a body skip this step.
136 if (!func.getBody().empty()) {
137 Block &entryBlock = func.getBody().front();
139 for (
const auto &[argIdx, shape] : argsParsed) {
140 FailureOr<Type> newTensorType =
141 getUpdatedTensorType(argIdx, argTypes, shape);
142 if (
failed(newTensorType))
143 return signalPassFailure();
// Rebuild the signature: copy the old input types, then overwrite each
// requested entry. Each request is validated against the original types.
150 const FunctionType oldFunctionType = func.getFunctionType();
151 const ArrayRef<Type> oldInputTypes = oldFunctionType.getInputs();
152 SmallVector<Type> newInputs(oldInputTypes.begin(), oldInputTypes.end());
153 for (
const auto &[argIdx, shape] : argsParsed) {
154 FailureOr<Type> newTensorType =
155 getUpdatedTensorType(argIdx, oldInputTypes, shape);
156 if (
failed(newTensorType))
157 return signalPassFailure();
159 newInputs[argIdx] = newTensorType.value();
// Result types default to the old ones. If the last block ends in a
// func.return, the return operand types are used instead, so results pick
// up any retyped values (e.g. an argument that is returned directly).
// NOTE(review): only the last block's terminator is inspected. Returns in
// other blocks of a multi-block function are not considered.
163 const auto oldResultTypes = oldFunctionType.getResults();
164 SmallVector<Type> newResults(oldResultTypes.begin(), oldResultTypes.end());
165 if (!func.getBody().empty()) {
166 Block &lastBlock = func.getBody().back();
168 if (
auto returnOp = dyn_cast_or_null<func::ReturnOp>(terminator)) {
169 const auto returnTypes = returnOp.getOperandTypes();
170 newResults.assign(returnTypes.begin(), returnTypes.end());
// Install the rebuilt signature on the function.
174 const FunctionType newFunctionType =
175 oldFunctionType.clone(newInputs, newResults);
176 func.setFunctionType(newFunctionType);
// Body of the pass factory, which is declared as
// `createTosaInputShapePass(std::vector<std::string> args = {})`. It wraps
// the given input-shape specifications in a newly allocated TosaInputShape
// pass.
// NOTE(review): the enclosing function signature is not visible in this
// excerpt.
184 return std::make_unique<TosaInputShape>(args);
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
std::unique_ptr< Pass > createTosaInputShapePass(std::vector< std::string > args={})
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.