// Debug category used by LLVM_DEBUG / -debug-only in this file. It was
// previously "vector-broadcast-lowering", which does not match what this file
// lowers (vector scan -> arith ops via ScanToArithOps), so it is named after
// the scan lowering instead.
#define DEBUG_TYPE "vector-scan-lowering"
// NOTE(review): fragment of
//   static Value genOperator(Location loc, Value x, Value y,
//                            vector::CombiningKind kind,
//                            PatternRewriter &rewriter)
// The start of the signature, the `switch` header, the int/float `if`/`else`
// lines and the `break`s are not visible in this listing.
//
// Builds the arith operation that combines `x` and `y` according to `kind`
// and returns its result Value.
47 vector::CombiningKind kind,
49 using vector::CombiningKind;
// Integer vs. float is decided from the element type of `x`; `cast<>` to
// VectorType means `x` is expected to be a vector value.
51 auto elType = cast<VectorType>(x.
getType()).getElementType();
52 bool isInt = elType.isIntOrIndex();
// Stays null if no case below assigns it, in which case a null Value is
// returned; the scan rewrite in this file asserts the result is non-null.
54 Value combinedResult{
nullptr};
// ADD and MUL have both an integer and a float form; the choice between them
// (presumably on `isInt`) is on lines not visible here.
56 case CombiningKind::ADD:
58 combinedResult = rewriter.
create<arith::AddIOp>(loc, x, y);
60 combinedResult = rewriter.
create<arith::AddFOp>(loc, x, y);
62 case CombiningKind::MUL:
64 combinedResult = rewriter.
create<arith::MulIOp>(loc, x, y);
66 combinedResult = rewriter.
create<arith::MulFOp>(loc, x, y);
// Integer-only kinds map one-to-one onto arith ops (unsigned/signed min/max,
// bitwise and/or/xor).
68 case CombiningKind::MINUI:
69 combinedResult = rewriter.
create<arith::MinUIOp>(loc, x, y);
71 case CombiningKind::MINSI:
72 combinedResult = rewriter.
create<arith::MinSIOp>(loc, x, y);
74 case CombiningKind::MAXUI:
75 combinedResult = rewriter.
create<arith::MaxUIOp>(loc, x, y);
77 case CombiningKind::MAXSI:
78 combinedResult = rewriter.
create<arith::MaxSIOp>(loc, x, y);
80 case CombiningKind::AND:
81 combinedResult = rewriter.
create<arith::AndIOp>(loc, x, y);
83 case CombiningKind::OR:
84 combinedResult = rewriter.
create<arith::OrIOp>(loc, x, y);
86 case CombiningKind::XOR:
87 combinedResult = rewriter.
create<arith::XOrIOp>(loc, x, y);
// Float min: MINF is lowered exactly like MINIMUMF (arith.minimumf).
89 case CombiningKind::MINF:
90 case CombiningKind::MINIMUMF:
91 combinedResult = rewriter.
create<arith::MinimumFOp>(loc, x, y);
// Float max: MAXF is lowered exactly like MAXIMUMF (arith.maximumf).
93 case CombiningKind::MAXF:
94 case CombiningKind::MAXIMUMF:
95 combinedResult = rewriter.
create<arith::MaximumFOp>(loc, x, y);
98 return combinedResult;
// NOTE(review): fragment of
//   static bool isValidKind(bool isInt, vector::CombiningKind kind)
// The signature, the `switch (kind)` header, the `break`s and the closing
// braces are not visible in this listing.
//
// Returns true when `kind` is consistent with the element category given by
// `isInt` (true = integer/index element type, presumably the caller's
// isIntOrIndex() result; false = float). Each kind is classified as FLOAT or
// INT. Kinds not handled by the switch keep the INVALID initializer, which
// never matches either category.
104 using vector::CombiningKind;
105 enum class KindType { FLOAT, INT, INVALID };
106 KindType type{KindType::INVALID};
// Float-only kinds.
108 case CombiningKind::MINF:
109 case CombiningKind::MINIMUMF:
110 case CombiningKind::MAXF:
111 case CombiningKind::MAXIMUMF:
112 type = KindType::FLOAT;
// Integer-only kinds: unsigned/signed min/max and bitwise ops.
114 case CombiningKind::MINUI:
115 case CombiningKind::MINSI:
116 case CombiningKind::MAXUI:
117 case CombiningKind::MAXSI:
118 case CombiningKind::AND:
119 case CombiningKind::OR:
120 case CombiningKind::XOR:
121 type = KindType::INT;
// ADD and MUL exist in both forms, so they always match the element type.
123 case CombiningKind::ADD:
124 case CombiningKind::MUL:
125 type = isInt ? KindType::INT : KindType::FLOAT;
// Valid only if the kind's category agrees with the element type.
128 bool isValidIntKind = (type == KindType::INT) && isInt;
129 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
130 return (isValidIntKind || isValidFloatKind);
// NOTE(review): fragment of the ScanToArithOps rewrite. The signature,
// several declarations (destShape, reductionShape, reductionType, offsets,
// sizes, strides, scanOffsets, scanSizes, result, output, reduction), the
// kind-validity check and the closing braces are not visible in this
// listing.
//
// Lowers the scan op by unrolling it along its reduction dimension. Each
// step extracts one slice of the source, combines it with the running value
// using the scan's combining kind, and inserts the step's output into the
// result vector.
177 auto loc = scanOp.getLoc();
178 VectorType destType = scanOp.getDestType();
// Element category of the destination. Presumably checked with isValidKind
// on lines not shown here.
180 auto elType = destType.getElementType();
181 bool isInt = elType.isIntOrIndex();
188 int64_t reductionDim = scanOp.getReductionDim();
// Selects which slice is combined with the running value (see `y` below).
189 bool inclusive = scanOp.getInclusive();
190 int64_t destRank = destType.getRank();
191 VectorType initialValueType = scanOp.getInitialValueType();
192 int64_t initialValueRank = initialValueType.getRank();
// Each unrolled step covers a single position along the reduction dimension.
195 reductionShape[reductionDim] = 1;
200 sizes[reductionDim] = 1;
// Running values carried between iterations; their per-iteration updates
// are on lines not visible here.
204 Value lastOutput, lastInput;
// NOTE(review): `i` is an `int` compared against a dimension size that is
// presumably int64_t (from destType's shape); an int64_t index would avoid
// the narrowing / mixed-width comparison.
205 for (
int i = 0; i < destShape[reductionDim]; i++) {
206 offsets[reductionDim] = i;
// Slice i of the scan source.
208 Value input = rewriter.
create<vector::ExtractStridedSliceOp>(
209 loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
// Seed the output from the op's initial value: broadcast when the initial
// value is rank-0, shape_cast to the slice type otherwise (the `else` line is
// not visible). The condition that selects this seeding path is also not
// visible here.
216 if (initialValueRank == 0) {
218 output = rewriter.
create<vector::BroadcastOp>(
219 loc, input.
getType(), scanOp.getInitialValue());
221 output = rewriter.
create<vector::ShapeCastOp>(
222 loc, input.
getType(), scanOp.getInitialValue());
// Combine the running value with the current slice when inclusive,
// otherwise with the previous input slice.
226 Value y = inclusive ? input : lastInput;
227 output =
genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
// genOperator yields a null Value for combining kinds it does not handle.
228 assert(output !=
nullptr);
// Write this step's output into the accumulated result vector.
230 result = rewriter.
create<vector::InsertStridedSliceOp>(
231 loc, output, result, offsets, strides);
// Second result: the running value after the loop (presumably the last
// step's output), converted to the initial value's type. For rank-0, element
// 0 is extracted and broadcast; otherwise a shape_cast is used.
237 if (initialValueRank == 0) {
238 Value v = rewriter.
create<vector::ExtractOp>(loc, lastOutput, 0);
240 rewriter.
create<vector::BroadcastOp>(loc, initialValueType, v);
242 reduction = rewriter.
create<vector::ShapeCastOp>(loc, initialValueType,
// Replace the scan op with both results: the scanned vector and the
// reduction value.
246 rewriter.
replaceOp(scanOp, {result, reduction});
// Fragment of populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
// PatternBenefit benefit = 1): registers the ScanToArithOps rewrite pattern in
// `patterns` with the given benefit.
254 patterns.
add<ScanToArithOps>(patterns.
getContext(), benefit);
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks to see if the vector combining kind is consistent with the integer or float elem...
static Value genOperator(Location loc, Value x, Value y, vector::CombiningKind kind, PatternRewriter &rewriter)
This function constructs the appropriate integer or float operation given the vector combining kind a...
TypedAttr getZeroAttr(Type type)
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...