MLIR
17.0.0git
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
Namespaces
mlir
This header declares functions that assist transformations in the MemRef dialect.
mlir::tosa
Macros
#define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
Functions
static LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, SmallVectorImpl<int64_t> &reshapeOutputShape)
There are two potential ways of implementing broadcast; this pass implements the NumPy style.
static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the ranks of the operands equal.
#define GEN_PASS_DEF_TOSAMAKEBROADCASTABLE
Definition at line 23 of file TosaMakeBroadcastable.cpp.
static LogicalResult computeReshapeOutput(ArrayRef<int64_t> higherRankShape, ArrayRef<int64_t> lowerRankShape, SmallVectorImpl<int64_t> &reshapeOutputShape)
There are two potential ways of implementing broadcast:
a. https://www.tensorflow.org/xla/broadcasting#formal_definition
b. https://numpy.org/doc/stable/user/basics.broadcasting.html
This pass currently implements b (NumPy style). As the first step in the broadcasting process, it inserts RESHAPE operators that increase the rank of the lower-rank operand, because the TOSA operators that support broadcasting require their operands to have equal rank.
Definition at line 49 of file TosaMakeBroadcastable.cpp.
References mlir::failure(), and mlir::success().
Referenced by reshapeLowerToHigher().
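The NumPy-style rank alignment described above can be sketched without MLIR as follows. This is a hypothetical standalone analogue, not the code in TosaMakeBroadcastable.cpp: `computeReshapeOutputSketch` is an illustrative name, `std::vector` stands in for `ArrayRef`/`SmallVectorImpl`, `std::nullopt` stands in for `failure()`, and dynamic dimensions are assumed away. Shapes are aligned at their trailing dimension, missing leading dimensions become 1, and each aligned pair must be equal or contain a 1.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical standalone analogue of computeReshapeOutput.
// Assumptions: static shapes only; std::nullopt plays the role of failure().
std::optional<std::vector<int64_t>>
computeReshapeOutputSketch(const std::vector<int64_t> &higherRankShape,
                           const std::vector<int64_t> &lowerRankShape) {
  assert(higherRankShape.size() >= lowerRankShape.size() &&
         "first argument must have the higher (or equal) rank");
  // Start from all ones: leading dimensions the lower-rank operand lacks
  // become size-1 dimensions after the reshape.
  std::vector<int64_t> out(higherRankShape.size(), 1);
  // Align both shapes at their trailing dimension and walk towards the front.
  auto i = static_cast<std::ptrdiff_t>(higherRankShape.size()) - 1;
  auto j = static_cast<std::ptrdiff_t>(lowerRankShape.size()) - 1;
  for (; i >= 0 && j >= 0; --i, --j) {
    int64_t hi = higherRankShape[i];
    int64_t lo = lowerRankShape[j];
    // NumPy rule: dimensions are compatible if equal or if either is 1.
    if (hi != lo && hi != 1 && lo != 1)
      return std::nullopt;
    // A reshape must preserve the element count, so the operand keeps its
    // own extent; the actual broadcast happens in the consuming TOSA op.
    out[i] = lo;
  }
  return out;
}
```

For example, aligning a `[3, 1]` operand against a `[2, 3, 4]` operand yields the reshape target `[1, 3, 1]`, while `[5]` against `[2, 3, 4]` is rejected because 5 and 4 are neither equal nor 1.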
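static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, RankedTensorType outputType, Value &input1, Value &input2)
Common code to create the reshape op where necessary to make the ranks of the operands equal.
input1 and input2 are passed by reference and are updated when a reshape changes the rank of one of them. The caller is expected to use the updated values to rewrite the original operator, with the RESHAPE now in the graph.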
Definition at line 81 of file TosaMakeBroadcastable.cpp.
References mlir::Type::cast(), computeReshapeOutput(), mlir::OpBuilder::create(), mlir::Type::dyn_cast(), mlir::failed(), mlir::Builder::getDenseI64ArrayAttr(), getShape(), mlir::Value::getType(), mlir::RewriterBase::notifyMatchFailure(), and mlir::success().
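The in-place update contract of reshapeLowerToHigher can be illustrated on shapes alone. The sketch below is hypothetical: `reshapeLowerToHigherSketch` and `padLowerShape` are illustrative names, shapes stand in for `Value`s, and rewriting a shape in place stands in for creating a RESHAPE op. It assumes that equal ranks and incompatible shapes both count as "no rewrite" (returning `false`), loosely mirroring the `notifyMatchFailure()` and `failed()` paths listed among the references, and it only checks static dimensions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper: prepend 1s to `lower` so it matches the rank of
// `higher`, after checking NumPy-style compatibility dimension by dimension.
static bool padLowerShape(const Shape &higher, Shape &lower) {
  assert(higher.size() >= lower.size());
  Shape padded(higher.size() - lower.size(), 1);
  padded.insert(padded.end(), lower.begin(), lower.end());
  for (std::size_t k = 0; k < higher.size(); ++k)
    if (padded[k] != higher[k] && padded[k] != 1 && higher[k] != 1)
      return false;  // incompatible: leave `lower` untouched
  lower = std::move(padded);
  return true;
}

// Hypothetical shape-only analogue of reshapeLowerToHigher. Both shapes are
// taken by reference, like input1/input2, and only the lower-rank one is
// rewritten. Returns false when no rewrite is performed.
bool reshapeLowerToHigherSketch(Shape &shape1, Shape &shape2) {
  if (shape1.size() == shape2.size())
    return false;  // ranks already match: nothing to reshape
  if (shape1.size() > shape2.size())
    return padLowerShape(shape1, shape2);
  return padLowerShape(shape2, shape1);
}
```

Taking both operands by reference lets the caller stay agnostic about which side was reshaped: after a successful call it simply uses both (possibly updated) values to build the broadcasting op.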