25 #define DEBUG_TYPE "vector-broadcast-lowering"
33 using vector::CombiningKind;
34 enum class KindType { FLOAT, INT, INVALID };
35 KindType type{KindType::INVALID};
37 case CombiningKind::MINNUMF:
38 case CombiningKind::MINIMUMF:
39 case CombiningKind::MAXNUMF:
40 case CombiningKind::MAXIMUMF:
41 type = KindType::FLOAT;
44 case CombiningKind::MINSI:
45 case CombiningKind::MAXUI:
46 case CombiningKind::MAXSI:
47 case CombiningKind::AND:
48 case CombiningKind::OR:
49 case CombiningKind::XOR:
52 case CombiningKind::ADD:
53 case CombiningKind::MUL:
54 type = isInt ? KindType::INT : KindType::FLOAT;
57 bool isValidIntKind = (type == KindType::INT) && isInt;
58 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
59 return (isValidIntKind || isValidFloatKind);
104 LogicalResult matchAndRewrite(vector::ScanOp scanOp,
106 auto loc = scanOp.getLoc();
107 VectorType destType = scanOp.getDestType();
109 auto elType = destType.getElementType();
110 bool isInt = elType.isIntOrIndex();
115 Value result = arith::ConstantOp::create(rewriter, loc, resType,
117 int64_t reductionDim = scanOp.getReductionDim();
118 bool inclusive = scanOp.getInclusive();
119 int64_t destRank = destType.getRank();
120 VectorType initialValueType = scanOp.getInitialValueType();
121 int64_t initialValueRank = initialValueType.getRank();
124 reductionShape[reductionDim] = 1;
129 sizes[reductionDim] = 1;
133 Value lastOutput, lastInput;
134 for (
int i = 0; i < destShape[reductionDim]; i++) {
135 offsets[reductionDim] = i;
137 Value input = vector::ExtractStridedSliceOp::create(
138 rewriter, loc, reductionType, scanOp.getSource(), scanOffsets,
139 scanSizes, scanStrides);
145 if (initialValueRank == 0) {
147 output = vector::BroadcastOp::create(rewriter, loc, input.
getType(),
148 scanOp.getInitialValue());
150 output = vector::ShapeCastOp::create(rewriter, loc, input.
getType(),
151 scanOp.getInitialValue());
155 Value y = inclusive ? input : lastInput;
159 result = vector::InsertStridedSliceOp::create(rewriter, loc, output,
160 result, offsets, strides);
166 if (initialValueRank == 0) {
167 Value v = vector::ExtractOp::create(rewriter, loc, lastOutput, 0);
169 vector::BroadcastOp::create(rewriter, loc, initialValueType, v);
171 reduction = vector::ShapeCastOp::create(rewriter, loc, initialValueType,
175 rewriter.
replaceOp(scanOp, {result, reduction});
union mlir::linalg::@1227::ArityGroupAndKind::Kind kind
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks to see if the vector combining kind is consistent with the integer or float elem...
TypedAttr getZeroAttr(Type type)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...