26 #define GEN_PASS_DEF_TOSAINFERSHAPESPASS
27 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
45 isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
56 class TypeModificationState {
58 TypeModificationState() =
default;
60 ~TypeModificationState() {
62 assert(oldTypes.empty() &&
"unhandled type modifications");
68 oldTypes.emplace_back(value, value.
getType());
76 for (
auto [value, type] : oldTypes)
88 for (
auto [value, oldType] : oldTypes) {
100 tensor::CastOp castValue;
104 if (canBeRefined(use->getOwner()))
115 tensor::CastOp::create(builder, value.
getLoc(), oldType, value);
131 void propagateShapesInRegion(
Region ®ion, TypeModificationState &state);
133 void propagateShapesToTosaIf(
Operation &op, TypeModificationState &state) {
134 IfOp ifOp = dyn_cast<IfOp>(op);
146 auto oldType = cast<ShapedType>(blockArg.getType());
148 if (inferredTy.hasRank()) {
149 Type newType = oldType.clone(inferredTy.getShape());
150 state.setType(blockArg, newType);
156 ifOp.getOperand(i + 1).getType());
161 if (!joinedKnowledge)
166 propagateShapesInRegion(region, state);
170 void propagateShapesToTosaWhile(
Operation &op, TypeModificationState &state) {
171 WhileOp whileOp = dyn_cast<WhileOp>(op);
180 bool hasNewTypes =
true;
181 while (hasNewTypes) {
182 TypeModificationState localState;
187 for (
int i = 0, s = argTypes.size(); i < s; i++) {
188 localState.setType(block.
getArgument(i), argTypes[i]);
192 propagateShapesInRegion(bodyRegion, localState);
196 for (
auto &block : bodyRegion)
198 yieldOps.push_back(yieldOp);
200 assert(yieldOps.size() == 1 &&
"missing or non-unique yield op");
203 for (
auto ty : argTypes) {
207 for (
auto yieldOp : yieldOps) {
211 yieldTypeInfo[it.index()] =
217 if (yieldTypeInfo.size() != argTypes.size()) {
218 op.
emitWarning(
"has a tosa.yield with the incorrect number of operands");
224 for (
int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
225 Type newType = yieldTypeInfo[i].getType();
226 hasNewTypes |= (newType != argTypes[i]);
227 argTypes[i] = newType;
231 localState.rollBack();
238 for (
unsigned int i = 0, s = argTypes.size(); i < s; i++) {
239 state.setType(region.front().getArgument(i), argTypes[i]);
242 propagateShapesInRegion(region, state);
246 void propagateShapesInRegion(
Region ®ion, TypeModificationState &state) {
249 for (
auto &block : region) {
254 propagateShapesToTosaIf(op, state);
255 propagateShapesToTosaWhile(op, state);
257 InferShapedTypeOpInterface shapeInterface =
258 dyn_cast<InferShapedTypeOpInterface>(op);
265 .inferReturnTypeComponents(
270 for (
auto it : llvm::zip(op.
getResults(), returnedShapes)) {
271 Value result = std::get<0>(it);
277 auto currentKnowledge =
282 inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
283 inferredKnowledge.hasRank = predictedShape.
hasRank();
284 if (predictedShape.
hasRank()) {
285 for (
auto dim : predictedShape.
getDims()) {
286 inferredKnowledge.sizes.push_back(dim);
297 state.setType(result, newKnowledge.getType());
306 void validateSameOperandsAndResultRankTrait(
Region ®ion) {
308 for (
auto &block : region) {
309 for (
auto &op : block) {
319 WhileOp whileOp = dyn_cast<WhileOp>(op);
320 IfOp ifOp = dyn_cast<IfOp>(op);
321 if (whileOp || ifOp) {
324 validateSameOperandsAndResultRankTrait(next);
333 struct TosaInferShapes
334 :
public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
336 void runOnOperation()
override {
337 func::FuncOp func = getOperation();
338 TypeModificationState state;
339 propagateShapesInRegion(func.getBody(), state);
342 validateSameOperandsAndResultRankTrait(func.getBody());
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
StringRef getNamespace() const
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
This class helps build Operations.
This class represents an operand of an operation.
This class verifies that op has same ranks for all operands and results types, if known.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext()
Return the context this region is inserted in.
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult verifySameOperandsAndResultRank(Operation *op)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
Statically known information for a particular Value.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getPessimisticValueState()
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)