// Tag for this file's debug output (LLVM convention: consumed by LLVM_DEBUG
// and selectable with -debug-only=vector-broadcast-lowering).
#define DEBUG_TYPE "vector-broadcast-lowering"
// Rewrites a vector::BroadcastOp into lower-level vector ops:
// vector::ExtractOp, vector::BroadcastOp on a lower-rank type, ub::PoisonOp
// and vector::InsertOp. Returns a LogicalResult.
// NOTE(review): this excerpt is incomplete. The second signature line,
// several statements, branch bodies, all return statements and the closing
// brace are missing, so the comments below describe only the visible lines.
39 LogicalResult matchAndRewrite(vector::BroadcastOp op,
41 auto loc = op.getLoc();
42 VectorType dstType = op.getResultVectorType();
43 VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
44 Type eltType = dstType.getElementType();
49 op,
"broadcast from scalar already in lowered form");
// srcType is dereferenced from here on. Presumably the (partially visible)
// match-failure exit just above handles a non-vector source, for which the
// dyn_cast yields null — TODO confirm against the missing lines.
52 int64_t srcRank = srcType.getRank();
53 int64_t dstRank = dstType.getRank();
// Case: 1-D result from a source of rank <= 1.
57 if (srcRank <= 1 && dstRank == 1) {
// Extract a single element from the source. The assert checks that the
// extracted value is a scalar rather than a vector.
59 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
61 assert(!isa<VectorType>(ext.
getType()) &&
"expected scalar");
// Case: source rank strictly below result rank. Broadcast the whole source
// to `resType`, whose declaration is not visible here (presumably dstType
// without its leading dim — confirm). Then insert that value at every
// leading position of the result.
75 if (srcRank < dstRank) {
79 vector::BroadcastOp::create(rewriter, loc, resType, op.getSource());
// Seed with a poison vector of the result type; the loop overwrites each
// leading position d in [0, dstType.getDimSize(0)).
80 Value result = ub::PoisonOp::create(rewriter, loc, dstType);
81 for (
int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
82 result = vector::InsertOp::create(rewriter, loc, bcst,
result, d);
// Remaining case: source and result have equal rank.
88 assert(srcRank == dstRank);
// Scan dimensions for a size mismatch between source and result. The
// statements inside the `if` are not visible in this excerpt.
90 for (
int64_t r = 0; r < dstRank; r++)
91 if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
118 VectorType::get(dstType.getShape().drop_front(), eltType,
119 dstType.getScalableDims().drop_front());
// The type built just above (declaration line not visible; presumably
// `resType`) is dstType with its leading dim and that dim's scalable flag
// dropped.
// The body of the following `if` is not visible. Presumably it bails out
// when the mismatch is not in dim 0 and the leading result dim is scalable,
// because the loops below unroll over the static dstType.getDimSize(0) —
// TODO confirm.
125 if (m != 0 && dstType.getScalableDims()[0]) {
// Start from poison; the insert loops below fill every leading position.
130 Value result = ub::PoisonOp::create(rewriter, loc, dstType);
// Variant 1: extract source position 0 once, broadcast it to resType, and
// insert that same value at every leading position. The guarding condition
// is not visible; presumably the stretched dimension is dim 0.
133 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0);
134 Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
135 for (
int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
136 result = vector::InsertOp::create(rewriter, loc, bcst,
result, d);
// Variant 2: for each leading index d, extract source position d, broadcast
// it to resType, and insert it at result position d.
139 for (
int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
140 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
141 Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
142 result = vector::InsertOp::create(rewriter, loc, bcst,
result, d);
// Register one BroadcastOpLowering instance in `patterns`. It is constructed
// with the pattern set's MLIRContext and the caller-supplied `benefit`.
153 patterns.
add<BroadcastOpLowering>(patterns.
getContext(), benefit);
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.