36 #define DEBUG_TYPE "vector-broadcast-lowering"
// Returns whether the combining kind `kind` is consistent with the element
// type of the scan: an integer element type when `isInt` is true, a
// floating-point element type otherwise. Each kind is mapped to a KindType,
// and the result is true only when that classification agrees with `isInt`.
// A kind that is never assigned a classification keeps INVALID and yields
// false.
// NOTE(review): this listing omits several source lines (the `switch (kind)`
// header, `break` statements, at least one case label, the assignment for the
// integer-only group, and the closing braces); the comments below describe
// only the visible lines.
43 static bool isValidKind(
bool isInt, vector::CombiningKind kind) {
44 using vector::CombiningKind;
// Which element types a combining kind applies to.
45 enum class KindType { FLOAT, INT, INVALID };
// Start as INVALID: INVALID matches neither branch of the final check, so any
// kind not assigned below is rejected.
46 KindType type{KindType::INVALID};
// Floating-point min/max variants are float-only.
48 case CombiningKind::MINNUMF:
49 case CombiningKind::MINIMUMF:
50 case CombiningKind::MAXNUMF:
51 case CombiningKind::MAXIMUMF:
52 type = KindType::FLOAT;
// Signed/unsigned min/max and bitwise kinds. Presumably classified as INT,
// but the assignment line for this group is not visible here — confirm.
55 case CombiningKind::MINSI:
56 case CombiningKind::MAXUI:
57 case CombiningKind::MAXSI:
58 case CombiningKind::AND:
59 case CombiningKind::OR:
60 case CombiningKind::XOR:
// ADD and MUL are defined for both element kinds, so their classification
// simply follows `isInt` (which makes them always valid).
63 case CombiningKind::ADD:
64 case CombiningKind::MUL:
65 type = isInt ? KindType::INT : KindType::FLOAT;
// Valid only when the kind's classification agrees with the element type.
68 bool isValidIntKind = (type == KindType::INT) && isInt;
69 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
70 return (isValidIntKind || isValidFloatKind);
117 auto loc = scanOp.getLoc();
118 VectorType destType = scanOp.getDestType();
120 auto elType = destType.getElementType();
121 bool isInt = elType.isIntOrIndex();
128 int64_t reductionDim = scanOp.getReductionDim();
129 bool inclusive = scanOp.getInclusive();
130 int64_t destRank = destType.getRank();
131 VectorType initialValueType = scanOp.getInitialValueType();
132 int64_t initialValueRank = initialValueType.getRank();
135 reductionShape[reductionDim] = 1;
140 sizes[reductionDim] = 1;
144 Value lastOutput, lastInput;
145 for (
int i = 0; i < destShape[reductionDim]; i++) {
146 offsets[reductionDim] = i;
148 Value input = rewriter.
create<vector::ExtractStridedSliceOp>(
149 loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
156 if (initialValueRank == 0) {
158 output = rewriter.
create<vector::BroadcastOp>(
159 loc, input.
getType(), scanOp.getInitialValue());
161 output = rewriter.
create<vector::ShapeCastOp>(
162 loc, input.
getType(), scanOp.getInitialValue());
166 Value y = inclusive ? input : lastInput;
170 result = rewriter.
create<vector::InsertStridedSliceOp>(
171 loc, output, result, offsets, strides);
177 if (initialValueRank == 0) {
178 Value v = rewriter.
create<vector::ExtractOp>(loc, lastOutput, 0);
180 rewriter.
create<vector::BroadcastOp>(loc, initialValueType, v);
182 reduction = rewriter.
create<vector::ShapeCastOp>(loc, initialValueType,
186 rewriter.
replaceOp(scanOp, {result, reduction});
194 patterns.
add<ScanToArithOps>(patterns.
getContext(), benefit);
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks whether the vector combining kind is consistent with the integer or floating-point element type.
TypedAttr getZeroAttr(Type type)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...