MLIR
15.0.0git
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
Functions

static LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, SmallVectorImpl<int64_t> &reshapeOutputShape)
    There are two potential ways of implementing broadcast; this pass uses NumPy-style broadcasting.

static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value input1, Value input2, Value &outInput1, Value &outInput2)
    Common code to create the reshape op where necessary to make the ranks of the operands equal.
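computeReshapeOutput()

static LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, SmallVectorImpl<int64_t> &reshapeOutputShape)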
There are two potential ways of implementing broadcast:
a. XLA-style broadcasting: https://www.tensorflow.org/xla/broadcasting#formal_definition
b. NumPy-style broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html

This pass currently implements b (NumPy style). It inserts RESHAPE operators that increase the rank of the lower-rank operand as the first step of the broadcasting process, because the TOSA operators that support broadcast require their operands to have equal rank.
Definition at line 42 of file TosaMakeBroadcastable.cpp.
References mlir::failure(), and mlir::success().
Referenced by reshapeLowerToHigher().
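The shape computation behind this step can be illustrated with a standalone C++ sketch. It is not the MLIR implementation: it assumes all dimensions are static and positive, substitutes `std::vector<int64_t>` and a `bool` result for `ArrayRef`, `SmallVectorImpl`, and `LogicalResult`, and uses the hypothetical name `computeReshapeOutputSketch`. Following NumPy rules, the lower-rank shape is aligned with the trailing dimensions of the higher-rank shape, and the new leading dimensions are set to 1.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone sketch (not the MLIR code): computes the shape a
// RESHAPE would give the lower-rank operand so that its rank matches the
// higher-rank operand. Assumes all dimensions are static and positive.
// Returns false if the shapes are not NumPy-broadcast compatible.
bool computeReshapeOutputSketch(const std::vector<int64_t> &higherRankShape,
                                const std::vector<int64_t> &lowerRankShape,
                                std::vector<int64_t> &reshapeOutputShape) {
  const std::size_t higherRank = higherRankShape.size();
  const std::size_t lowerRank = lowerRankShape.size();
  if (lowerRank > higherRank)
    return false;

  // Start with all ones: the newly added leading dimensions have size 1.
  reshapeOutputShape.assign(higherRank, 1);

  // Walk both shapes from the trailing (rightmost) dimension, NumPy style.
  for (std::size_t k = 0; k < lowerRank; ++k) {
    const int64_t hi = higherRankShape[higherRank - 1 - k];
    const int64_t lo = lowerRankShape[lowerRank - 1 - k];
    // Two dimensions are compatible if they are equal or either one is 1.
    if (hi != lo && hi != 1 && lo != 1)
      return false;
    // A reshape preserves the element count, so the lower-rank dimension is
    // copied unchanged; expanding size-1 dimensions happens later.
    reshapeOutputShape[higherRank - 1 - k] = lo;
  }
  return true;
}
```

For example, under these assumptions a shape `[3]` paired with `[2, 4, 3]` yields `[1, 1, 3]`, while `[4]` paired with `[2, 3]` is rejected because 4 and 3 are neither equal nor 1. Because a RESHAPE cannot change the element count, the sketch only prepends 1s; expanding size-1 dimensions is left to the broadcasting operator itself.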
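reshapeLowerToHigher()

static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value input1, Value input2, Value &outInput1, Value &outInput2)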
Common code to create the reshape op where necessary to make the ranks of the operands equal.
Returns, through outInput1 and outInput2, the updated input1 and input2 for the original inputs. The caller is expected to use these to rewrite the original operator now that the RESHAPE is in the graph.
Definition at line 74 of file TosaMakeBroadcastable.cpp.
References mlir::applyPatternsAndFoldGreedily(), mlir::Type::cast(), computeReshapeOutput(), mlir::OpBuilder::create(), mlir::Type::dyn_cast(), mlir::failed(), mlir::failure(), mlir::Builder::getI64ArrayAttr(), getShape(), mlir::Value::getType(), mlir::RewriterBase::replaceOpWithNewOp(), and mlir::success().
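The rank-equalizing logic can likewise be sketched without MLIR. In this hypothetical analogue, `reshapeLowerToHigherSketch` works on plain shapes instead of `Value`s and returns the shapes the rewritten operator would receive rather than creating a RESHAPE op through `PatternRewriter`. Treating equal ranks as a failure (nothing to rewrite) is an assumption of the sketch, as is the NumPy compatibility check in the hypothetical helper `padToRank`.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper: left-pads `lower` with 1s up to the rank of `higher`,
// failing if any aligned trailing dimension pair is not broadcast compatible.
// Precondition (checked by the caller): higher.size() > lower.size().
static bool padToRank(const Shape &higher, const Shape &lower, Shape &out) {
  const std::size_t offset = higher.size() - lower.size();
  out.assign(offset, 1);
  for (std::size_t j = 0; j < lower.size(); ++j) {
    const int64_t h = higher[offset + j];
    const int64_t l = lower[j];
    if (h != l && h != 1 && l != 1)
      return false;
    out.push_back(l);
  }
  return true;
}

// Hypothetical standalone analogue of reshapeLowerToHigher: instead of
// creating a reshape op, it returns the operand shapes that the rewritten
// operator would see. Operand order is preserved. Returns false when the
// ranks already match (assumed: nothing to rewrite) or shapes are incompatible.
bool reshapeLowerToHigherSketch(const Shape &input1, const Shape &input2,
                                Shape &outInput1, Shape &outInput2) {
  if (input1.size() == input2.size())
    return false;
  const bool firstIsHigher = input1.size() > input2.size();
  const Shape &higher = firstIsHigher ? input1 : input2;
  const Shape &lower = firstIsHigher ? input2 : input1;

  Shape reshaped;
  if (!padToRank(higher, lower, reshaped))
    return false;

  // The higher-rank operand passes through; the lower-rank one is "reshaped".
  if (firstIsHigher) {
    outInput1 = higher;
    outInput2 = reshaped;
  } else {
    outInput1 = reshaped;
    outInput2 = higher;
  }
  return true;
}
```

For instance, calling it with `[5]` and `[2, 3, 5]` would give `[1, 1, 5]` and `[2, 3, 5]`. Because the operand order is preserved, a caller can substitute the results directly for input1 and input2 when rebuilding the original operator.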