28 #define GEN_PASS_DEF_TOSAINFERSHAPES
29 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
47 isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
58 class TypeModificationState {
60 TypeModificationState() =
default;
62 ~TypeModificationState() {
64 assert(oldTypes.empty() &&
"unhandled type modifications");
70 oldTypes.emplace_back(value, value.
getType());
78 for (
auto [value, type] : oldTypes)
90 for (
auto [value, oldType] : oldTypes) {
91 tensor::CastOp castedValue;
92 for (
auto &use : value.
getUses()) {
93 if (canBeRefined(use.getOwner()))
99 castedValue = builder.create<tensor::CastOp>(oldType, value);
102 use.set(castedValue);
115 void propagateShapesInRegion(
Region ®ion, TypeModificationState &state);
117 void propagateShapesToTosaIf(
Operation &op, TypeModificationState &state) {
118 IfOp ifOp = dyn_cast<IfOp>(op);
130 auto oldType = cast<ShapedType>(blockArg.getType());
132 if (inferredTy.hasRank()) {
133 Type newType = oldType.clone(inferredTy.getShape());
134 state.setType(blockArg, newType);
140 ifOp.getOperand(i + 1).getType());
145 if (!joinedKnowledge)
150 propagateShapesInRegion(region, state);
154 void propagateShapesToTosaWhile(
Operation &op, TypeModificationState &state) {
155 WhileOp whileOp = dyn_cast<WhileOp>(op);
164 bool hasNewTypes =
true;
165 while (hasNewTypes) {
166 TypeModificationState localState;
171 for (
int i = 0, s = argTypes.size(); i < s; i++) {
172 localState.setType(block.
getArgument(i), argTypes[i]);
176 propagateShapesInRegion(bodyRegion, localState);
180 for (
auto &block : bodyRegion)
182 yieldOps.push_back(yieldOp);
184 assert(yieldOps.size() == 1 &&
"missing or non-unique yield op");
187 for (
auto ty : argTypes) {
191 for (
auto yieldOp : yieldOps) {
195 yieldTypeInfo[it.index()] =
201 if (yieldTypeInfo.size() != argTypes.size()) {
202 op.
emitWarning(
"has a tosa.yield with the incorrect number of operands");
208 for (
int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
209 Type newType = yieldTypeInfo[i].getType();
210 hasNewTypes |= (newType != argTypes[i]);
211 argTypes[i] = newType;
215 localState.rollBack();
222 for (
unsigned int i = 0, s = argTypes.size(); i < s; i++) {
223 state.setType(region.front().getArgument(i), argTypes[i]);
226 propagateShapesInRegion(region, state);
230 void propagateShapesInRegion(
Region ®ion, TypeModificationState &state) {
233 for (
auto &block : region) {
238 propagateShapesToTosaIf(op, state);
239 propagateShapesToTosaWhile(op, state);
241 InferShapedTypeOpInterface shapeInterface =
242 dyn_cast<InferShapedTypeOpInterface>(op);
249 .inferReturnTypeComponents(
254 for (
auto it : llvm::zip(op.
getResults(), returnedShapes)) {
255 Value result = std::get<0>(it);
261 auto currentKnowledge =
266 inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
267 inferredKnowledge.hasRank = predictedShape.
hasRank();
268 if (predictedShape.
hasRank()) {
269 for (
auto dim : predictedShape.
getDims()) {
270 inferredKnowledge.sizes.push_back(dim);
281 state.setType(result, newKnowledge.getType());
290 struct TosaInferShapes
291 :
public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
293 void runOnOperation()
override {
294 func::FuncOp func = getOperation();
295 TypeModificationState state;
296 propagateShapesInRegion(func.getBody(), state);
303 return std::make_unique<TosaInferShapes>();
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the entire group.
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loaded.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
operand_type_range getOperandTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext()
Return the context this region is inserted in.
ShapedTypeComponents that represents the components of a ShapedType.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Location getLoc() const
Return the location of this value.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::unique_ptr< Pass > createTosaInferShapesPass()
Include the generated interface declarations.
Statically known information for a particular Value.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getPessimisticValueState()
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)