25 #define DEBUG_TYPE "vector-broadcast-lowering"
39 LogicalResult matchAndRewrite(vector::BroadcastOp op,
41 auto loc = op.getLoc();
42 VectorType dstType = op.getResultVectorType();
43 VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
44 Type eltType = dstType.getElementType();
49 op,
"broadcast from scalar already in lowered form");
52 int64_t srcRank = srcType.getRank();
53 int64_t dstRank = dstType.getRank();
57 if (srcRank <= 1 && dstRank == 1) {
59 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
61 assert(!isa<VectorType>(ext.
getType()) &&
"expected scalar");
75 if (srcRank < dstRank) {
79 vector::BroadcastOp::create(rewriter, loc, resType, op.getSource());
80 Value result = ub::PoisonOp::create(rewriter, loc, dstType);
81 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
82 result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
88 assert(srcRank == dstRank);
90 for (int64_t r = 0; r < dstRank; r++)
91 if (srcType.getDimSize(r) != dstType.getDimSize(r)) {
119 dstType.getScalableDims().drop_front());
120 Value result = ub::PoisonOp::create(rewriter, loc, dstType);
123 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), 0);
124 Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
125 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d)
126 result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
129 if (dstType.getScalableDims()[0]) {
133 for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
134 Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), d);
135 Value bcst = vector::BroadcastOp::create(rewriter, loc, resType, ext);
136 result = vector::InsertOp::create(rewriter, loc, bcst, result, d);
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...