27 #define GEN_PASS_DEF_TOSAINFERSHAPES
28 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
48 TosaDialect::getDialectNamespace() ||
49 isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
60 class TypeModificationState {
62 TypeModificationState() =
default;
64 ~TypeModificationState() {
66 assert(oldTypes.empty() &&
"unhandled type modifications");
72 oldTypes.emplace_back(value, value.
getType());
81 for (
auto [value, type] : oldTypes)
93 for (
auto [value, oldType] : oldTypes) {
94 for (
auto &use : value.
getUses()) {
95 if (isReplaceableUser(use.getOwner()))
99 builder.setInsertionPoint(use.getOwner());
102 use.set(builder.create<tensor::CastOp>(loc, oldType, value));
115 void propagateShapesInRegion(
Region &region, TypeModificationState &state);
117 void propagateShapesToTosaIf(
Operation &op, TypeModificationState &state) {
118 IfOp ifOp = dyn_cast<IfOp>(op);
130 auto oldType = cast<ShapedType>(blockArg.getType());
132 if (inferredTy.hasRank()) {
133 Type newType = oldType.clone(inferredTy.getShape());
134 state.setType(blockArg, newType);
140 ifOp.getOperand(i + 1).getType());
145 if (!joinedKnowledge)
150 propagateShapesInRegion(region, state);
154 void propagateShapesToTosaWhile(
Operation &op, TypeModificationState &state) {
155 WhileOp whileOp = dyn_cast<WhileOp>(op);
164 bool hasNewTypes =
true;
165 while (hasNewTypes) {
166 TypeModificationState localState;
171 for (
int i = 0, s = argTypes.size(); i < s; i++) {
172 localState.setType(block.
getArgument(i), argTypes[i]);
176 propagateShapesInRegion(bodyRegion, localState);
180 for (
auto &block : bodyRegion)
182 yieldOps.push_back(yieldOp);
184 assert(yieldOps.size() == 1 &&
"missing or non-unique yield op");
187 for (
auto ty : argTypes) {
191 for (
auto yieldOp : yieldOps) {
195 yieldTypeInfo[it.index()] =
201 if (yieldTypeInfo.size() != argTypes.size()) {
202 op.
emitWarning(
"has a tosa.yield with the incorrect number of operands");
208 for (
int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
209 Type newType = yieldTypeInfo[i].getType();
210 hasNewTypes |= (newType != argTypes[i]);
211 argTypes[i] = newType;
222 for (
unsigned int i = 0, s = argTypes.size(); i < s; i++) {
223 state.setType(region.front().getArgument(i), argTypes[i]);
226 propagateShapesInRegion(region, state);
230 void propagateShapesInRegion(
Region &region, TypeModificationState &state) {
231 for (
auto &block : region) {
237 propagateShapesToTosaIf(op, state);
238 propagateShapesToTosaWhile(op, state);
240 InferShapedTypeOpInterface shapeInterface =
241 dyn_cast<InferShapedTypeOpInterface>(op);
248 .inferReturnTypeComponents(
253 for (
auto it : llvm::zip(op.
getResults(), returnedShapes)) {
254 Value result = std::get<0>(it);
260 auto currentKnowledge =
265 inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
266 inferredKnowledge.hasRank = predictedShape.
hasRank();
267 if (predictedShape.
hasRank()) {
268 for (
auto dim : predictedShape.
getDims()) {
269 inferredKnowledge.sizes.push_back(dim);
280 state.setType(result, newKnowledge.getType());
289 struct TosaInferShapes
290 :
public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
292 void runOnOperation()
override {
293 func::FuncOp func = getOperation();
294 TypeModificationState state;
295 propagateShapesInRegion(func.getBody(), state);
302 return std::make_unique<TosaInferShapes>();
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
StringRef getNamespace() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
This class helps build Operations.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::unique_ptr< Pass > createTosaInferShapesPass()
Include the generated interface declarations.
Statically known information for a particular Value.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getPessimisticValueState()
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)