28 #define GEN_PASS_DEF_TOSAINFERSHAPES
29 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
47 isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
58 class TypeModificationState {
60 TypeModificationState() =
default;
62 ~TypeModificationState() {
64 assert(oldTypes.empty() &&
"unhandled type modifications");
70 oldTypes.emplace_back(value, value.
getType());
78 for (
auto [value, type] : oldTypes)
90 for (
auto [value, oldType] : oldTypes) {
102 tensor::CastOp castValue;
106 if (canBeRefined(use->getOwner()))
117 builder.create<tensor::CastOp>(value.
getLoc(), oldType, value);
133 void propagateShapesInRegion(
Region ®ion, TypeModificationState &state);
135 void propagateShapesToTosaIf(
Operation &op, TypeModificationState &state) {
136 IfOp ifOp = dyn_cast<IfOp>(op);
148 auto oldType = cast<ShapedType>(blockArg.getType());
150 if (inferredTy.hasRank()) {
151 Type newType = oldType.clone(inferredTy.getShape());
152 state.setType(blockArg, newType);
158 ifOp.getOperand(i + 1).getType());
163 if (!joinedKnowledge)
168 propagateShapesInRegion(region, state);
172 void propagateShapesToTosaWhile(
Operation &op, TypeModificationState &state) {
173 WhileOp whileOp = dyn_cast<WhileOp>(op);
182 bool hasNewTypes =
true;
183 while (hasNewTypes) {
184 TypeModificationState localState;
189 for (
int i = 0, s = argTypes.size(); i < s; i++) {
190 localState.setType(block.
getArgument(i), argTypes[i]);
194 propagateShapesInRegion(bodyRegion, localState);
198 for (
auto &block : bodyRegion)
200 yieldOps.push_back(yieldOp);
202 assert(yieldOps.size() == 1 &&
"missing or non-unique yield op");
205 for (
auto ty : argTypes) {
209 for (
auto yieldOp : yieldOps) {
213 yieldTypeInfo[it.index()] =
219 if (yieldTypeInfo.size() != argTypes.size()) {
220 op.
emitWarning(
"has a tosa.yield with the incorrect number of operands");
226 for (
int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
227 Type newType = yieldTypeInfo[i].getType();
228 hasNewTypes |= (newType != argTypes[i]);
229 argTypes[i] = newType;
233 localState.rollBack();
240 for (
unsigned int i = 0, s = argTypes.size(); i < s; i++) {
241 state.setType(region.front().getArgument(i), argTypes[i]);
244 propagateShapesInRegion(region, state);
248 void propagateShapesInRegion(
Region ®ion, TypeModificationState &state) {
251 for (
auto &block : region) {
256 propagateShapesToTosaIf(op, state);
257 propagateShapesToTosaWhile(op, state);
259 InferShapedTypeOpInterface shapeInterface =
260 dyn_cast<InferShapedTypeOpInterface>(op);
267 .inferReturnTypeComponents(
272 for (
auto it : llvm::zip(op.
getResults(), returnedShapes)) {
273 Value result = std::get<0>(it);
279 auto currentKnowledge =
284 inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
285 inferredKnowledge.hasRank = predictedShape.
hasRank();
286 if (predictedShape.
hasRank()) {
287 for (
auto dim : predictedShape.
getDims()) {
288 inferredKnowledge.sizes.push_back(dim);
299 state.setType(result, newKnowledge.getType());
308 struct TosaInferShapes
309 :
public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
311 void runOnOperation()
override {
312 func::FuncOp func = getOperation();
313 TypeModificationState state;
314 propagateShapesInRegion(func.getBody(), state);
321 return std::make_unique<TosaInferShapes>();
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext()
Return the context this region is inserted in.
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::unique_ptr< Pass > createTosaInferShapesPass()
Include the generated interface declarations.
Statically known information for a particular Value.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getPessimisticValueState()
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)