28#define GEN_PASS_DEF_TOSAINFERSHAPESPASS
29#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
47 isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
58class TypeModificationState {
60 TypeModificationState() =
default;
62 ~TypeModificationState() {
64 assert(oldTypes.empty() &&
"unhandled type modifications");
68 void setType(Value value, Type type) {
70 oldTypes.emplace_back(value, value.
getType());
78 for (
auto [value, type] : oldTypes)
90 for (
auto [value, oldType] : oldTypes) {
94 llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
96 [](OpOperand &use) -> OpOperand * {
102 tensor::CastOp castValue;
105 for (OpOperand *use : uses) {
106 if (canBeRefined(use->getOwner()))
123 tensor::CastOp::create(builder, value.
getLoc(), oldType, value);
136 llvm::SmallVector<std::pair<Value, Type>> oldTypes;
141void validateSameOperandsAndResultRankTrait(
Region ®ion) {
143 for (
auto &block : region) {
144 for (
auto &op : block) {
145 if (!op.getDialect() ||
146 op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
154 WhileOp whileOp = dyn_cast<WhileOp>(op);
155 IfOp ifOp = dyn_cast<IfOp>(op);
156 if (whileOp || ifOp) {
158 for (
auto &next : op.getRegions()) {
159 validateSameOperandsAndResultRankTrait(next);
168struct TosaInferShapes
169 :
public tosa::impl::TosaInferShapesPassBase<TosaInferShapes> {
171 explicit TosaInferShapes() =
default;
172 explicit TosaInferShapes(
const TosaInferShapesPassOptions &
options)
173 : TosaInferShapes() {
174 this->foldShapeExpressions =
options.foldShapeExpressions;
175 this->convertFunctionBoundaries =
options.convertFunctionBoundaries;
178 void runOnOperation()
override {
179 func::FuncOp func = getOperation();
180 TypeModificationState state;
181 propagateShapesInRegion(func.getBody(), state);
184 if (foldShapeExpressions) {
186 func.walk<WalkOrder::PostOrder, ReverseIterator>(
187 [](tosa::ConstShapeOp op) {
193 validateSameOperandsAndResultRankTrait(func.getBody());
195 if (convertFunctionBoundaries)
196 convertFunctionReturnTypes(func);
200 void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
201 IfOp ifOp = dyn_cast<IfOp>(op);
206 Block &frontBlock = region.front();
213 auto oldType = cast<ShapedType>(blockArg.getType());
215 if (inferredTy.hasRank()) {
216 Type newType = oldType.clone(inferredTy.getShape());
217 state.setType(blockArg, newType);
223 ifOp.getOperand(i + 1).getType());
226 ValueKnowledge joinedKnowledge =
228 if (!joinedKnowledge)
233 propagateShapesInRegion(region, state);
237 void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
238 WhileOp whileOp = dyn_cast<WhileOp>(op);
247 bool hasNewTypes =
true;
248 while (hasNewTypes) {
249 TypeModificationState localState;
254 for (
int i = 0, s = argTypes.size(); i < s; i++) {
255 localState.setType(block.
getArgument(i), argTypes[i]);
259 propagateShapesInRegion(bodyRegion, localState);
262 llvm::SmallVector<YieldOp> yieldOps;
263 for (
auto &block : bodyRegion)
265 yieldOps.push_back(yieldOp);
267 assert(yieldOps.size() == 1 &&
"missing or non-unique yield op");
269 llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
270 for (
auto ty : argTypes) {
274 for (
auto yieldOp : yieldOps) {
275 for (
const auto &it : llvm::enumerate(yieldOp.getOperands())) {
278 yieldTypeInfo[it.index()] =
284 if (yieldTypeInfo.size() != argTypes.size()) {
286 "has a tosa.yield with the incorrect number of operands");
292 for (
int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
293 Type newType = yieldTypeInfo[i].getType();
294 hasNewTypes |= (newType != argTypes[i]);
295 argTypes[i] = newType;
300 localState.rollBack();
307 for (
unsigned int i = 0, s = argTypes.size(); i < s; i++) {
308 state.setType(region.front().getArgument(i), argTypes[i]);
311 propagateShapesInRegion(region, state);
315 void propagateShapesInRegion(Region ®ion, TypeModificationState &state) {
318 OperationFolder folder(ctx);
320 for (
auto &block : region) {
324 for (
auto it = block.
begin(); it != block.
end();) {
325 Operation &op = *it++;
329 propagateShapesToTosaIf(op, state);
330 propagateShapesToTosaWhile(op, state);
332 if (foldShapeExpressions &&
333 op.
hasTrait<OpTrait::tosa::TosaShapeOperator>()) {
334 (void)folder.tryToFold(&op);
338 InferShapedTypeOpInterface shapeInterface =
339 dyn_cast<InferShapedTypeOpInterface>(op);
343 SmallVector<ShapedTypeComponents> returnedShapes;
346 .inferReturnTypeComponents(
351 for (
auto it : llvm::zip(op.
getResults(), returnedShapes)) {
352 Value
result = std::get<0>(it);
353 ShapedTypeComponents predictedShape = std::get<1>(it);
357 Type resultTy =
result.getType();
358 auto currentKnowledge =
363 inferredKnowledge.dtype =
364 cast<ShapedType>(resultTy).getElementType();
365 inferredKnowledge.hasRank = predictedShape.
hasRank();
366 if (predictedShape.
hasRank()) {
367 for (
auto dim : predictedShape.
getDims()) {
368 inferredKnowledge.sizes.push_back(dim);
379 state.setType(
result, newKnowledge.getType());
386 void convertFunctionReturnTypes(func::FuncOp func) {
387 IRRewriter rewriter(func.getContext());
388 SmallVector<Type> newReturnTypes;
391 func.walk([&rewriter, &newReturnTypes](func::ReturnOp ret) {
392 SmallVector<Value> newReturnValues;
393 SmallVector<Value> maybeDeadCasts;
394 OperandRange returnOperands = ret.getOperands();
395 newReturnValues.reserve(returnOperands.size());
396 maybeDeadCasts.reserve(returnOperands.size());
397 newReturnTypes.reserve(newReturnTypes.size() + returnOperands.size());
399 for (
const Value &v : returnOperands) {
400 Value newReturnValue = v;
401 if (
auto castOp = v.getDefiningOp<tensor::CastOp>()) {
402 newReturnValue = castOp.getSource();
403 maybeDeadCasts.push_back(castOp);
405 newReturnValues.push_back(newReturnValue);
406 newReturnTypes.push_back(newReturnValue.
getType());
409 rewriter.setInsertionPoint(ret);
410 rewriter.replaceOpWithNewOp<func::ReturnOp>(ret, newReturnValues);
412 if (!maybeDeadCasts.empty()) {
413 llvm::for_each(maybeDeadCasts, [&](Value castVal) {
422 const FunctionType oldType = func.getFunctionType();
423 const FunctionType newType = FunctionType::get(
424 func.getContext(), oldType.getInputs(), newReturnTypes);
425 func.setType(newType);
static llvm::ManagedStatic< PassManagerOptions > options
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
TypeID getTypeID() const
Returns the unique identifier that corresponds to this dialect.
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class verifies that op has same ranks for all operands and results types, if known.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
InFlightDiagnostic emitWarning(const Twine &message={})
Emit a warning about this operation, reporting up to any diagnostic handlers that may be listening.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
DictionaryAttr getDiscardableAttrDictionary()
Return all of the discardable attributes on this operation as a DictionaryAttr.
operand_type_range getOperandTypes()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
MLIRContext * getContext()
Return the context this operation is associated with.
OpaqueProperties getPropertiesStorage()
Returns the properties storage.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
MLIRContext * getContext()
Return the context this region is inserted in.
bool hasRank() const
Return whether the shape has a rank.
ArrayRef< int64_t > getDims() const
Return the dimensions of the shape.
static TypeID get()
Construct a type info object for the given type T.
bool use_empty() const
Returns true if this value has no uses.
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult verifySameOperandsAndResultRank(Operation *op)
Include the generated interface declarations.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
static ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getPessimisticValueState()
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
static ValueKnowledge getKnowledgeFromType(Type type)