MLIR 15.0.0git
TosaMakeBroadcastable.cpp File Reference
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

Functions

static LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, SmallVectorImpl<int64_t> &reshapeOutputShape)
 There are two potential ways of implementing broadcast; this pass implements the NumPy style.
 
static LogicalResult reshapeLowerToHigher (PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value input1, Value input2, Value &outInput1, Value &outInput2)
 Common code to create the reshape op where necessary to make the ranks of the operands equal.
 

Function Documentation

◆ computeReshapeOutput()

static LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                                          ArrayRef<int64_t> lowerRankShape,
                                          SmallVectorImpl<int64_t> &reshapeOutputShape)

There are two potential ways of implementing broadcast:
a. XLA style: https://www.tensorflow.org/xla/broadcasting#formal_definition
b. NumPy style: https://numpy.org/doc/stable/user/basics.broadcasting.html

This pass currently implements b (NumPy style). As the first step of the broadcasting process, it inserts RESHAPE operators that increase the rank of the lower-rank operand, because the TOSA operators that support broadcast require their operands to have equal rank.
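The NumPy rule referenced above can be illustrated with a small, self-contained C++ sketch. It is not the MLIR implementation of computeReshapeOutput(): the name numpyStyleRankExtension is hypothetical, std::vector and std::optional stand in for ArrayRef, SmallVectorImpl and LogicalResult, and the compatibility check follows plain NumPy rules rather than the code in this file. It right-aligns the lower-rank shape, prepends size-1 dimensions until the ranks match, and rejects shape pairs NumPy would not broadcast.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for computeReshapeOutput(): std::optional replaces the
// LogicalResult + out-parameter pair, and std::vector replaces ArrayRef and
// SmallVectorImpl. This follows plain NumPy rules, not the MLIR source.
std::optional<std::vector<int64_t>>
numpyStyleRankExtension(const std::vector<int64_t> &higherRankShape,
                        const std::vector<int64_t> &lowerRankShape) {
  assert(lowerRankShape.size() <= higherRankShape.size() &&
         "operands must be passed in (higher, lower) rank order");

  // Right-align the lower-rank shape by prepending size-1 dimensions, so the
  // result has the same rank as higherRankShape.
  const std::size_t missing = higherRankShape.size() - lowerRankShape.size();
  std::vector<int64_t> reshaped(missing, 1);
  reshaped.insert(reshaped.end(), lowerRankShape.begin(), lowerRankShape.end());

  // NumPy rule: aligned dimensions must be equal, or one of them must be 1.
  for (std::size_t i = 0; i < reshaped.size(); ++i) {
    const int64_t lowDim = reshaped[i];
    const int64_t highDim = higherRankShape[i];
    if (lowDim != highDim && lowDim != 1 && highDim != 1)
      return std::nullopt; // not broadcastable
  }
  return reshaped;
}
```

For example, a [2, 3, 4] operand and a [4] operand yield the reshape target [1, 1, 4], [2, 3, 4] and [3, 1] yield [1, 3, 1], while [2, 3, 4] and [5] are rejected.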

Definition at line 42 of file TosaMakeBroadcastable.cpp.

References mlir::failure(), and mlir::success().

Referenced by reshapeLowerToHigher().

◆ reshapeLowerToHigher()

static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
                                          Location loc,
                                          RankedTensorType outputType,
                                          Value input1,
                                          Value input2,
                                          Value &outInput1,
                                          Value &outInput2)

Common code to create the reshape op where necessary to make the ranks of the operands equal.

It returns the updated input1 and input2 through outInput1 and outInput2. The caller is expected to use these values to rewrite the original operator now that the RESHAPE is in the graph.
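The out-parameter contract described above can be sketched without MLIR by tracking only static shapes. Everything here is an assumption for illustration: the name reshapeLowerToHigherSketch, the Shape alias, the bool return standing in for LogicalResult, the treatment of equal ranks, and the NumPy compatibility check. The real function operates on MLIR Values and creates an actual RESHAPE op; this sketch only shows how the lower-rank operand is identified, rank-extended, and returned in its original position.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical, shape-only model of reshapeLowerToHigher(): an operand is just
// its static shape, and "inserting a RESHAPE" means replacing that shape with
// its rank-extended form. Returns false where a failure() would be reported,
// mirroring the outInput1/outInput2 out-parameters of the real function.
bool reshapeLowerToHigherSketch(const Shape &input1, const Shape &input2,
                                Shape &outInput1, Shape &outInput2) {
  // Equal ranks: nothing to reshape, so report failure and let the caller
  // keep the original op unchanged (a design choice of this sketch).
  if (input1.size() == input2.size())
    return false;

  const bool firstIsHigher = input1.size() > input2.size();
  const Shape &higher = firstIsHigher ? input1 : input2;
  const Shape &lower = firstIsHigher ? input2 : input1;

  // Prepend size-1 dimensions to the lower-rank shape (NumPy right alignment).
  Shape reshaped(higher.size() - lower.size(), 1);
  reshaped.insert(reshaped.end(), lower.begin(), lower.end());
  for (std::size_t i = 0; i < reshaped.size(); ++i)
    if (reshaped[i] != higher[i] && reshaped[i] != 1 && higher[i] != 1)
      return false; // shapes are not broadcast-compatible

  // Write the results back in the original operand order so the caller can
  // rebuild the original op from (outInput1, outInput2).
  outInput1 = firstIsHigher ? input1 : reshaped;
  outInput2 = firstIsHigher ? reshaped : input2;

  // Invariant required by broadcasting TOSA ops: both operands now share a rank.
  assert(outInput1.size() == outInput2.size());
  return true;
}
```

With input1 of shape [4] and input2 of shape [2, 3, 4], outInput1 becomes [1, 1, 4] and outInput2 stays [2, 3, 4], so the broadcasting operator can be rebuilt from operands of equal rank. Keeping the original operand order matters for non-commutative operators such as subtraction.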

Definition at line 74 of file TosaMakeBroadcastable.cpp.

References mlir::applyPatternsAndFoldGreedily(), mlir::Type::cast(), computeReshapeOutput(), mlir::OpBuilder::create(), mlir::Type::dyn_cast(), mlir::failed(), mlir::failure(), mlir::Builder::getI64ArrayAttr(), getShape(), mlir::Value::getType(), mlir::RewriterBase::replaceOpWithNewOp(), and mlir::success().