MLIR  17.0.0git
TosaMakeBroadcastable.cpp File Reference
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"

Namespaces

 mlir
 This header declares functions that assist transformations in the MemRef dialect.
 
 mlir::tosa
 

Macros

#define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
 

Functions

static LogicalResult computeReshapeOutput (ArrayRef< int64_t > higherRankShape, ArrayRef< int64_t > lowerRankShape, SmallVectorImpl< int64_t > &reshapeOutputShape)
 There are two potential ways of implementing broadcast; this pass uses the NumPy-style approach.
 
static LogicalResult reshapeLowerToHigher (PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value &input1, Value &input2)
 Common code to create the reshape op, where necessary, to make the ranks of the operands equal.
 

Macro Definition Documentation

◆ GEN_PASS_DEF_TOSAMAKEBROADCASTABLE

#define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE

Definition at line 23 of file TosaMakeBroadcastable.cpp.

Function Documentation

◆ computeReshapeOutput()

static LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, SmallVectorImpl<int64_t> &reshapeOutputShape)

There are two potential ways of implementing broadcast:

a. XLA style: https://www.tensorflow.org/xla/broadcasting#formal_definition
b. NumPy style: https://numpy.org/doc/stable/user/basics.broadcasting.html

This pass currently implements b (NumPy style). As a first step in the broadcasting process, it inserts RESHAPE operators that increase the rank of the lower-rank operand, because the TOSA operators that support broadcasting require their operands to have equal rank.

Definition at line 49 of file TosaMakeBroadcastable.cpp.

References mlir::failure(), and mlir::success().

Referenced by reshapeLowerToHigher().
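The rank-extension step described above can be sketched without MLIR. This is a hypothetical standalone illustration, not the MLIR implementation: the name computeReshapeOutputSketch, the std::vector/std::optional interface, and the assumption that every dimension is static are all introduced here. It right-aligns the lower-rank shape against the higher-rank one, pads the leading dimensions with 1, and rejects dimension pairs that violate the NumPy broadcasting rule.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, MLIR-free sketch of NumPy-style rank extension.
// Assumes all dimensions are static (no dynamic sizes).
// Returns the shape the lower-rank operand should be reshaped to, or
// std::nullopt if the two shapes can never broadcast together.
std::optional<std::vector<int64_t>>
computeReshapeOutputSketch(const std::vector<int64_t> &higherRankShape,
                           const std::vector<int64_t> &lowerRankShape) {
  // Precondition: the caller passes the higher-rank shape first.
  assert(lowerRankShape.size() <= higherRankShape.size());

  // Start from [1, 1, ..., 1] with the higher rank; leading dims stay 1.
  std::vector<int64_t> reshapeOutputShape(higherRankShape.size(), 1);

  // Right-align the lower-rank shape against the higher-rank one.
  const std::size_t offset = higherRankShape.size() - lowerRankShape.size();
  for (std::size_t j = 0; j < lowerRankShape.size(); ++j) {
    const int64_t higherDim = higherRankShape[offset + j];
    const int64_t lowerDim = lowerRankShape[j];
    // NumPy rule: two dims are compatible if they are equal or one is 1.
    if (higherDim != lowerDim && higherDim != 1 && lowerDim != 1)
      return std::nullopt;
    // A reshape must keep the element count, so copy the operand's own extent.
    reshapeOutputShape[offset + j] = lowerDim;
  }
  return reshapeOutputShape;
}
```

For example, a higher-rank shape {2, 3, 4} and a lower-rank shape {3, 1} yield {1, 3, 1}, while {5} against the same higher-rank shape is rejected because 5 and 4 differ and neither is 1.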

◆ reshapeLowerToHigher()

static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value &input1, Value &input2)

Common code to create the reshape op, where necessary, to make the ranks of the operands equal.

input1 and input2 are updated in place when a reshape is inserted. The caller is expected to use the updated values to rewrite the original operator, now that the RESHAPE is in the graph.

Definition at line 81 of file TosaMakeBroadcastable.cpp.

References mlir::Type::cast(), computeReshapeOutput(), mlir::OpBuilder::create(), mlir::Type::dyn_cast(), mlir::failed(), mlir::Builder::getDenseI64ArrayAttr(), getShape(), mlir::Value::getType(), mlir::RewriterBase::notifyMatchFailure(), and mlir::success().
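The in-place update contract described above can be illustrated with a small MLIR-free model. Everything in it is hypothetical: FakeValue stands in for mlir::Value, the string label stands in for the tosa.reshape op that the real function creates through the PatternRewriter, and the outputType and Location parameters are omitted. It is a sketch of the idea, not the MLIR code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an SSA value: a static shape plus a label
// recording which "op" produced it.
struct FakeValue {
  std::vector<int64_t> shape;
  std::string producer;
};

// Compact NumPy-style rank extension (static shapes only).
// Caller guarantees higher.size() >= lower.size().
static std::optional<std::vector<int64_t>>
extendRank(const std::vector<int64_t> &higher, const std::vector<int64_t> &lower) {
  std::vector<int64_t> out(higher.size(), 1);
  const std::size_t offset = higher.size() - lower.size();
  for (std::size_t j = 0; j < lower.size(); ++j) {
    const int64_t h = higher[offset + j], l = lower[j];
    if (h != l && h != 1 && l != 1)
      return std::nullopt; // incompatible dimensions
    out[offset + j] = l;
  }
  return out;
}

// Hypothetical sketch: makes the ranks of input1 and input2 equal by
// "reshaping" whichever one has lower rank, updating it in place.
// Returns false if there is nothing to rewrite or the shapes cannot broadcast.
bool reshapeLowerToHigherSketch(FakeValue &input1, FakeValue &input2) {
  if (input1.shape.size() == input2.shape.size())
    return false; // ranks already equal: nothing to rewrite

  const bool input1IsHigher = input1.shape.size() > input2.shape.size();
  FakeValue &higher = input1IsHigher ? input1 : input2;
  FakeValue &lower = input1IsHigher ? input2 : input1;

  std::optional<std::vector<int64_t>> newShape =
      extendRank(higher.shape, lower.shape);
  if (!newShape)
    return false; // leave both inputs untouched on failure

  // Stand-in for creating a tosa.reshape op: the lower-rank operand is
  // replaced by the reshaped value; the higher-rank operand is unchanged.
  lower = FakeValue{*newShape, "reshape(" + lower.producer + ")"};
  return true;
}
```

With input1 of shape {2, 3, 4} and input2 of shape {4}, the call returns true, leaves input1 alone, and replaces input2 with a value of shape {1, 1, 4}; a caller would then build the broadcasting operator from the updated input1 and input2. Returning false for equal ranks is a choice made in this sketch so that a rewrite pattern built on it would not fire again on its own output.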