37 #define DEBUG_TYPE "vector-broadcast-lowering"
51 VectorType dstType = op.getResultVectorType();
52 VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
53 Type eltType = dstType.getElementType();
62 int64_t srcRank = srcType.getRank();
63 int64_t dstRank = dstType.getRank();
66 if (srcRank <= 1 && dstRank == 1) {
69 ext = rewriter.
create<vector::ExtractElementOp>(loc, op.getSource());
71 ext = rewriter.
create<vector::ExtractOp>(loc, op.getSource(), 0);
85 if (srcRank < dstRank) {
89 rewriter.
create<vector::BroadcastOp>(loc, resType, op.getSource());
92 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
93 result = rewriter.
create<vector::InsertOp>(loc, bcst, result, d);
99 assert(srcRank == dstRank);
101 for (int64_t r = 0; r < dstRank; r++)
102 if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
134 Value ext = rewriter.
create<vector::ExtractOp>(loc, op.getSource(), 0);
135 Value bcst = rewriter.
create<vector::BroadcastOp>(loc, resType, ext);
136 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
137 result = rewriter.
create<vector::InsertOp>(loc, bcst, result, d);
140 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
141 Value ext = rewriter.
create<vector::ExtractOp>(loc, op.getSource(), d);
142 Value bcst = rewriter.
create<vector::BroadcastOp>(loc, resType, ext);
143 result = rewriter.
create<vector::InsertOp>(loc, bcst, result, d);
154 patterns.
add<BroadcastOpLowering>(patterns.
getContext(), benefit);
TypedAttr getZeroAttr(Type type)
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...