22#define GEN_PASS_DEF_TOSAINPUTSHAPE
23#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
32typedef std::pair<size_t, SmallVector<int64_t>> IdxAndShape;
34FailureOr<IdxAndShape> parseInputShape(
Location loc, StringRef input) {
35 if (!input.consume_front(
"arg")) {
36 emitError(loc) <<
"expected prefix 'arg' at the start of " << input;
40 const size_t colonPos = input.find(
':');
41 if (colonPos == StringRef::npos) {
42 emitError(loc) <<
"expected ':' after argument index in '" << input <<
"'";
46 const StringRef indexStr = input.substr(0, colonPos);
47 input = input.substr(colonPos + 1);
50 if (indexStr.getAsInteger(10,
index) ||
index < 0) {
51 emitError(loc) <<
"invalid argument index, got " << indexStr;
56 while (!input.empty()) {
57 const size_t xPos = input.find(
"x");
59 if (xPos == StringRef::npos) {
63 dimStr = input.substr(0, xPos);
64 input = input.substr(xPos + 1);
68 if (dimStr.getAsInteger(10, dimVal) || dimVal <= 0) {
71 shape.push_back(dimVal);
74 const auto idxAndShape = std::make_pair(
index,
shape);
83FailureOr<SmallVector<IdxAndShape>>
84parseInputShapes(
Location loc,
const std::vector<std::string> &args) {
86 for (
const std::string &arg : args) {
87 const auto maybeInputShape = parseInputShape(loc, arg);
88 if (
failed(maybeInputShape))
90 inputShapes.push_back(maybeInputShape.value());
/// Defaulted constructor.
97 TosaInputShape() =
default;
/// Constructs the pass from shape specifications of the form
/// `arg<index>:<d0>x...x<dN>` (the format parsed by parseInputShape).
/// Delegates to the default constructor. The rest of the body is not visible
/// in this listing.
/// NOTE(review): presumably it stores `args` into the `args` member that
/// runOnOperation reads; confirm.
99 explicit TosaInputShape(std::vector<std::string> args) : TosaInputShape() {
/// Gives the function this pass runs on the static argument shapes requested
/// in `args`.
///
/// Visible steps:
///   1. Parse `args`.
///   2. Validate each request with getUpdatedTensorType.
///   3. Rebuild the function's input types with the requested shapes.
///   4. Take the result types from a func.return terminator when one is
///      found; otherwise keep the old results.
///   5. Install the new FunctionType.
/// An invalid request emits an error on the function and signals pass
/// failure.
103 void runOnOperation()
override {
// `args` entries come from the pass configuration rather than the IR, so
// their diagnostics are attached to an unknown location.
// NOTE(review): `context` is declared on a line not visible here.
105 const Location unknownLoc = UnknownLoc::get(context);
106 const auto maybeArgsParsed = parseInputShapes(unknownLoc, args);
107 if (
failed(maybeArgsParsed))
109 const SmallVector<IdxAndShape> argsParsed = maybeArgsParsed.value();
110 func::FuncOp func = getOperation();
// Checks one (argIdx, requestedShape) request against `argTypes`. On success
// it returns the argument's tensor type with its shape replaced by
// `requestedShape` (element type kept). On failure it emits an error on
// `func` and returns failure. Failure cases:
//   - argIdx is out of range;
//   - the argument is not a TensorType;
//   - the requested shape is incompatible with the existing shape.
// NOTE(review): the guarding conditions on original lines 122 and 127 are not
// visible. Presumably they are `!tensorType` and a failed
// verifyCompatibleShape(originalShape, requestedShape); confirm.
112 const auto getUpdatedTensorType =
113 [&](
size_t argIdx, ArrayRef<Type> argTypes,
114 ArrayRef<int64_t> requestedShape) -> FailureOr<Type> {
115 const size_t numInputs = argTypes.size();
116 if (argIdx >= numInputs)
117 return func.emitError()
118 <<
"provided arg index " << argIdx
119 <<
" is larger than number of inputs " << numInputs <<
".";
// dyn_cast yields null for non-tensor argument types.
121 auto tensorType = dyn_cast<TensorType>(argTypes[argIdx]);
123 return func.emitError()
124 <<
"expected tensor type, got " << argTypes[argIdx];
126 const ArrayRef<int64_t> originalShape = tensorType.getShape();
128 return func.emitError()
130 <<
" has incompatible shape with requested input shape ("
131 << requestedShape <<
"), got " << tensorType;
132 return tensorType.cloneWith(requestedShape, tensorType.getElementType());
// Step 2: validate every request. `argTypes` is declared on a line not
// visible here.
// NOTE(review): presumably it holds the entry block's argument types, and
// the hidden rest of this loop retypes the matching block argument; confirm.
136 Block &entryBlock = func.getBody().front();
138 for (
const auto &[argIdx, shape] : argsParsed) {
139 FailureOr<Type> newTensorType =
140 getUpdatedTensorType(argIdx, argTypes, shape);
141 if (
failed(newTensorType))
142 return signalPassFailure();
// Step 3: rebuild the function's input types from its declared inputs. If
// the same index appears more than once in `args`, the last entry wins here.
// NOTE(review): this calls getUpdatedTensorType a second time for every
// request. If the declared input types match the block argument types
// checked above, this validation cannot fail, and the new types could be
// reused instead of recomputed.
148 const FunctionType oldFunctionType = func.getFunctionType();
149 const ArrayRef<Type> oldInputTypes = oldFunctionType.getInputs();
150 SmallVector<Type> newInputs(oldInputTypes.begin(), oldInputTypes.end());
151 for (
const auto &[argIdx, shape] : argsParsed) {
152 FailureOr<Type> newTensorType =
153 getUpdatedTensorType(argIdx, oldInputTypes, shape);
154 if (
failed(newTensorType))
155 return signalPassFailure();
157 newInputs[argIdx] = newTensorType.value();
// Step 4: take the result types from a func.return terminator when there is
// one; otherwise keep the previous result types.
// NOTE(review): `terminator` is declared on a line not visible here;
// presumably it is lastBlock's terminator, so only the last block is
// consulted.
161 Block &lastBlock = func.getBody().back();
163 SmallVector<Type> newResults;
164 if (
auto returnOp = dyn_cast_or_null<func::ReturnOp>(terminator)) {
165 const auto types = returnOp.getOperandTypes();
166 newResults.assign(types.begin(), types.end());
168 const auto types = oldFunctionType.getResults();
169 newResults.assign(types.begin(), types.end());
// Step 5: install the updated signature on the function.
171 const FunctionType newFunctionType =
172 oldFunctionType.
clone(newInputs, newResults);
173 func.setFunctionType(newFunctionType);
// Build a new TosaInputShape pass from the caller-supplied shape
// specifications. The enclosing function signature is not visible in this
// listing.
181 return std::make_unique<TosaInputShape>(args);
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
std::unique_ptr< Pass > createTosaInputShapePass(std::vector< std::string > args={})
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.