// Debug category name for this file (presumably consumed by LLVM's
// debug-output macros - confirm).
// NOTE(review): the string says "broadcast", but the code visible in this
// file handles vector::ScanOp - confirm this is not a copy-paste leftover.
25#define DEBUG_TYPE "vector-broadcast-lowering"
/// Returns true when `kind` agrees with the element category selected by
/// `isInt`: an integer-classified kind with `isInt == true`, or a
/// float-classified kind with `isInt == false`.
///
/// NOTE(review): this listing omits some original lines. The `switch` header,
/// the `break` statements and the closing braces are not visible here, so the
/// comments below describe only the statements that are shown.
32static bool isValidKind(
bool isInt, vector::CombiningKind kind) {
33 using vector::CombiningKind;
// Local classification of `kind`. It starts as INVALID, so a kind that is
// never classified fails both checks at the end of the function.
34 enum class KindType { FLOAT, INT, INVALID };
35 KindType type{KindType::INVALID};
// Floating-point min/max kinds are classified as FLOAT.
37 case CombiningKind::MINNUMF:
38 case CombiningKind::MINIMUMF:
39 case CombiningKind::MAXNUMF:
40 case CombiningKind::MAXIMUMF:
41 type = KindType::FLOAT;
// Integer min/max and bitwise kinds.
// NOTE(review): the assignment for this group is not visible in this
// excerpt - presumably `type = KindType::INT`; confirm against the full file.
43 case CombiningKind::MINUI:
44 case CombiningKind::MINSI:
45 case CombiningKind::MAXUI:
46 case CombiningKind::MAXSI:
47 case CombiningKind::AND:
48 case CombiningKind::OR:
49 case CombiningKind::XOR:
// ADD and MUL take their classification from `isInt`, so they pass the
// checks below for either element category.
52 case CombiningKind::ADD:
53 case CombiningKind::MUL:
54 type = isInt ? KindType::INT : KindType::FLOAT;
// The kind is valid only when its classification matches `isInt`.
57 bool isValidIntKind = (type == KindType::INT) && isInt;
58 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
59 return (isValidIntKind || isValidFloatKind);
/// Pattern entry point for vector::ScanOp. The visible code steps through the
/// reduction dimension one index at a time, extracts a slice of the source
/// for each index, and inserts a per-step `output` into `result`.
///
/// NOTE(review): this listing omits several original lines (some call
/// arguments, parts of the control flow, and the end of the function), so the
/// comments below describe only the statements that are shown.
104 LogicalResult matchAndRewrite(vector::ScanOp scanOp,
105 PatternRewriter &rewriter)
const override {
106 auto loc = scanOp.getLoc();
107 VectorType destType = scanOp.getDestType();
108 ArrayRef<int64_t> destShape = destType.getShape();
109 auto elType = destType.getElementType();
// True when the element type is an integer or index type.
110 bool isInt = elType.isIntOrIndex();
// `result` starts as a constant of the destination type and is updated by the
// insert inside the loop below.
// NOTE(review): the constant's value argument is cut off in this excerpt.
114 VectorType resType = destType;
115 Value
result = arith::ConstantOp::create(rewriter, loc, resType,
117 int64_t reductionDim = scanOp.getReductionDim();
118 bool inclusive = scanOp.getInclusive();
119 int64_t destRank = destType.getRank();
120 VectorType initialValueType = scanOp.getInitialValueType();
121 int64_t initialValueRank = initialValueType.getRank();
123 SmallVector<int64_t> reductionShape(destShape);
124 SmallVector<bool> reductionScalableDims(destType.getScalableDims());
// A scalable reduction dimension is rejected with the message below.
// NOTE(review): the call receiving these arguments is cut off in this
// excerpt - presumably rewriter.notifyMatchFailure; confirm.
126 if (reductionScalableDims[reductionDim])
128 scanOp,
"Trying to reduce scalable dimension - not yet supported!");
// Type of one slice: the destination shape with the reduction dimension set
// to 1.
132 reductionShape[reductionDim] = 1;
133 VectorType reductionType =
134 VectorType::get(reductionShape, elType, reductionScalableDims);
// Slice coordinates: zero offsets, unit strides, and the destination sizes
// with the reduction dimension set to 1. Only the reduction-dimension offset
// is changed inside the loop.
136 SmallVector<int64_t> offsets(destRank, 0);
137 SmallVector<int64_t> strides(destRank, 1);
138 SmallVector<int64_t> sizes(destShape);
139 sizes[reductionDim] = 1;
// Values from the previous step.
// NOTE(review): their assignments are not visible in this excerpt.
143 Value lastOutput, lastInput;
// NOTE(review): the index is `int` while the bound comes from an
// ArrayRef<int64_t>.
144 for (
int i = 0; i < destShape[reductionDim]; i++) {
145 offsets[reductionDim] = i;
// Slice of the source for this step.
// NOTE(review): scanOffsets/scanSizes/scanStrides are declared on lines that
// are not shown here - presumably attribute forms of offsets/sizes/strides.
147 Value input = vector::ExtractStridedSliceOp::create(
148 rewriter, loc, reductionType, scanOp.getSource(), scanOffsets,
149 scanSizes, scanStrides);
// The initial value is converted to the slice type: broadcast when it is
// rank 0, shape-cast otherwise.
// NOTE(review): the declaration of `output`, the enclosing condition and the
// `else` between the two assignments are not visible in this excerpt.
155 if (initialValueRank == 0) {
157 output = vector::BroadcastOp::create(rewriter, loc, input.
getType(),
158 scanOp.getInitialValue());
160 output = vector::ShapeCastOp::create(rewriter, loc, input.
getType(),
161 scanOp.getInitialValue());
// An inclusive scan selects the current slice, an exclusive scan the previous
// one.
// NOTE(review): the statement that consumes `y` is not visible here.
165 Value y = inclusive ? input : lastInput;
// Write this step's output into `result` at the current offsets.
169 result = vector::InsertStridedSliceOp::create(rewriter, loc, output,
170 result, offsets, strides);
// Produce a value of `initialValueType`: for a rank-0 initial value, extract
// element 0 of `lastOutput` and broadcast it; otherwise shape-cast.
// NOTE(review): the declaration of `reduction`, the shape-cast's remaining
// argument and the end of the function are cut off in this excerpt.
176 if (initialValueRank == 0) {
177 Value v = vector::ExtractOp::create(rewriter, loc, lastOutput, 0);
179 vector::BroadcastOp::create(rewriter, loc, initialValueType, v);
181 reduction = vector::ShapeCastOp::create(rewriter, loc, initialValueType,
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks to see if the vector combining kind is consistent with the integer or float element type.
TypedAttr getZeroAttr(Type type)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Type getType() const
Return the type of this value.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.