35 #define DEBUG_TYPE "vector-broadcast-lowering"
62 LogicalResult matchAndRewrite(vector::GatherOp op,
64 VectorType resultTy = op.getType();
65 if (resultTy.getRank() < 2)
70 if (resultTy.getScalableDims().front())
74 Value indexVec = op.getIndexVec();
75 Value maskVec = op.getMask();
76 Value passThruVec = op.getPassThru();
83 for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
84 int64_t thisIdx[1] = {i};
87 rewriter.
create<vector::ExtractOp>(loc, indexVec, thisIdx);
89 rewriter.
create<vector::ExtractOp>(loc, maskVec, thisIdx);
90 Value passThruSubVec =
91 rewriter.
create<vector::ExtractOp>(loc, passThruVec, thisIdx);
92 Value subGather = rewriter.
create<vector::GatherOp>(
93 loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
96 rewriter.
create<vector::InsertOp>(loc, subGather, result, thisIdx);
126 LogicalResult matchAndRewrite(vector::GatherOp op,
128 Value base = op.getBase();
135 auto sourceType = subview.getSource().getType();
138 if (sourceType.getRank() != 2)
142 auto layout = subview.getResult().getType().getLayout();
143 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
144 if (!stridedLayoutAttr)
148 if (stridedLayoutAttr.getStrides().size() != 1)
151 int64_t srcTrailingDim = sourceType.getShape().back();
156 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
161 Value collapsed = rewriter.
create<memref::CollapseShapeOp>(
162 op.getLoc(), subview.getSource(), reassoc);
166 IntegerAttr stride = rewriter.
getIndexAttr(srcTrailingDim);
167 VectorType vType = op.getIndexVec().getType();
172 rewriter.
create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
176 Value newGather = rewriter.
create<vector::GatherOp>(
177 op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
178 newIdxs, op.getMask(), op.getPassThru());
191 LogicalResult matchAndRewrite(vector::GatherOp op,
193 VectorType resultTy = op.getType();
194 if (resultTy.getRank() != 1)
197 if (resultTy.isScalable())
201 Type elemTy = resultTy.getElementType();
205 Value condMask = op.getMask();
206 Value base = op.getBase();
209 if (
auto memType = dyn_cast<MemRefType>(base.
getType())) {
210 if (
auto stridesAttr =
211 dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
212 if (stridesAttr.getStrides().back() != 1)
218 loc, op.getIndexVectorType().clone(rewriter.
getIndexType()),
220 auto baseOffsets = llvm::to_vector(op.getIndices());
221 Value lastBaseOffset = baseOffsets.back();
223 Value result = op.getPassThru();
226 for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
227 int64_t thisIdx[1] = {i};
229 rewriter.
create<vector::ExtractOp>(loc, condMask, thisIdx);
230 Value index = rewriter.
create<vector::ExtractOp>(loc, indexVec, thisIdx);
232 rewriter.
createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
236 if (isa<MemRefType>(base.
getType())) {
240 b.
create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
241 int64_t zeroIdx[1] = {0};
242 extracted = b.
create<vector::ExtractOp>(loc, load, zeroIdx);
244 extracted = b.
create<tensor::ExtractOp>(loc, base, baseOffsets);
248 b.
create<vector::InsertOp>(loc, extracted, result, thisIdx);
249 b.
create<scf::YieldOp>(loc, newResult);
252 b.
create<scf::YieldOp>(loc, result);
257 .
create<scf::IfOp>(loc, condition, loadBuilder,
270 patterns.
add<FlattenGather, RemoveStrideFromGatherSource,
271 Gather1DToConditionalLoads>(patterns.
getContext(), benefit);
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...