26#define GEN_PASS_DEF_TOSAINFERSHAPESPASS 
   27#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" 
   45         isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
 
   56class TypeModificationState {
 
   58  TypeModificationState() = 
default;
 
   60  ~TypeModificationState() {
 
   62    assert(oldTypes.empty() && 
"unhandled type modifications");
 
   66  void setType(Value value, Type type) {
 
   68      oldTypes.emplace_back(value, value.
getType());
 
   76    for (
auto [value, type] : oldTypes)
 
   88    for (
auto [value, oldType] : oldTypes) {
 
   92      llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
 
   94          [](OpOperand &use) -> OpOperand * {
 
  100      tensor::CastOp castValue;
 
  103      for (OpOperand *use : uses) {
 
  104        if (canBeRefined(use->getOwner()))
 
  115              tensor::CastOp::create(builder, value.
getLoc(), oldType, value);
 
  128  llvm::SmallVector<std::pair<Value, Type>> oldTypes;
 
  131void propagateShapesInRegion(
Region &region, TypeModificationState &state);
 
  133void propagateShapesToTosaIf(
Operation &op, TypeModificationState &state) {
 
  134  IfOp ifOp = dyn_cast<IfOp>(op);
 
  146      auto oldType = cast<ShapedType>(blockArg.getType());
 
  148      if (inferredTy.hasRank()) {
 
  149        Type newType = oldType.clone(inferredTy.getShape());
 
  150        state.setType(blockArg, newType);
 
  156          ifOp.getOperand(i + 1).getType());
 
  161      if (!joinedKnowledge)
 
  166    propagateShapesInRegion(region, state);
 
  170void propagateShapesToTosaWhile(
Operation &op, TypeModificationState &state) {
 
  171  WhileOp whileOp = dyn_cast<WhileOp>(op);
 
  180  bool hasNewTypes = 
true;
 
  181  while (hasNewTypes) {
 
  182    TypeModificationState localState;
 
  187    for (
int i = 0, s = argTypes.size(); i < s; i++) {
 
  188      localState.setType(block.
getArgument(i), argTypes[i]);
 
  192    propagateShapesInRegion(bodyRegion, localState);
 
  196    for (
auto &block : bodyRegion)
 
  198        yieldOps.push_back(yieldOp);
 
  200    assert(yieldOps.size() == 1 && 
"missing or non-unique yield op");
 
  203    for (
auto ty : argTypes) {
 
  207    for (
auto yieldOp : yieldOps) {
 
  208      for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
 
  211        yieldTypeInfo[it.index()] =
 
  217    if (yieldTypeInfo.size() != argTypes.size()) {
 
  218      op.
emitWarning(
"has a tosa.yield with the incorrect number of operands");
 
  224    for (
int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
 
  225      Type newType = yieldTypeInfo[i].getType();
 
  226      hasNewTypes |= (newType != argTypes[i]);
 
  227      argTypes[i] = newType;
 
  231    localState.rollBack();
 
  238    for (
unsigned int i = 0, s = argTypes.size(); i < s; i++) {
 
  239      state.setType(region.front().getArgument(i), argTypes[i]);
 
  242    propagateShapesInRegion(region, state);
 
 
  246void propagateShapesInRegion(
Region &region, TypeModificationState &state) {
 
  249  for (
auto &block : region) {
 
 
  254      propagateShapesToTosaIf(op, state);
 
  255      propagateShapesToTosaWhile(op, state);
 
 
  257      InferShapedTypeOpInterface shapeInterface =
 
  258          dyn_cast<InferShapedTypeOpInterface>(op);
 
 
  265              .inferReturnTypeComponents(
 
  270        for (
auto it : llvm::zip(op.
getResults(), returnedShapes)) {
 
 
  277          auto currentKnowledge =
 
  282          inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
 
  283          inferredKnowledge.hasRank = predictedShape.
hasRank();
 
  284          if (predictedShape.
hasRank()) {
 
  285            for (
auto dim : predictedShape.
getDims()) {
 
  286              inferredKnowledge.sizes.push_back(dim);
 
  297          state.setType(
result, newKnowledge.getType());
 
  306void validateSameOperandsAndResultRankTrait(
Region &region) {
 
  308  for (
auto &block : region) {
 
  309    for (
auto &op : block) {
 
  319      WhileOp whileOp = dyn_cast<WhileOp>(op);
 
  320      IfOp ifOp = dyn_cast<IfOp>(op);
 
  321      if (whileOp || ifOp) {
 
  324          validateSameOperandsAndResultRankTrait(next);
 
  333struct TosaInferShapes
 
  336  void runOnOperation()
 override {
 
  337    func::FuncOp func = getOperation();
 
  338    TypeModificationState state;
 
  339    propagateShapesInRegion(func.getBody(), state);
 
  342    validateSameOperandsAndResultRankTrait(func.getBody());
 
Block represents an ordered list of Operations.
 
BlockArgument getArgument(unsigned i)
 
unsigned getNumArguments()
 
Operation * getTerminator()
Get the terminator operation of this block.
 
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
 
StringRef getNamespace() const
 
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
 
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
 
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
 
This class verifies that the op has the same rank for all operand and result types, if known.
 
Operation is the basic unit of execution within MLIR.
 
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
 
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
 
Value getOperand(unsigned idx)
 
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
 
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
 
Location getLoc()
The source location the operation was defined or derived from.
 
unsigned getNumOperands()
 
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
 
operand_type_range getOperandTypes()
 
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
 
operand_range getOperands()
Returns an iterator on the underlying Value's.
 
result_range getResults()
 
MLIRContext * getContext()
Return the context this operation is associated with.
 
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
 
This class contains a list of basic blocks and a link to the parent operation it is attached to.
 
MLIRContext * getContext()
Return the context this region is inserted in.
 
ShapedTypeComponents that represents the components of a ShapedType.
 
bool hasRank() const
Return whether the shape has a rank.
 
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
 
static TypeID get()
Construct a type info object for the given type T.
 
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
 
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
 
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
 
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
 
Type getType() const
Return the type of this value.
 
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
 
Location getLoc() const
Return the location of this value.
 
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
 
LogicalResult verifySameOperandsAndResultRank(Operation *op)
 
Include the generated interface declarations.
 
Statically known information for a particular Value.
 
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
 
static ValueKnowledge getPessimisticValueState()
 
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
 
static ValueKnowledge getKnowledgeFromType(Type type)