35 #define DEBUG_TYPE "vector-broadcast-lowering"
62 LogicalResult matchAndRewrite(vector::GatherOp op,
64 VectorType resultTy = op.getType();
65 if (resultTy.getRank() < 2)
70 if (resultTy.getScalableDims().front())
74 Value indexVec = op.getIndexVec();
75 Value maskVec = op.getMask();
76 Value passThruVec = op.getPassThru();
83 for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
84 int64_t thisIdx[1] = {i};
87 rewriter.
create<vector::ExtractOp>(loc, indexVec, thisIdx);
89 rewriter.
create<vector::ExtractOp>(loc, maskVec, thisIdx);
90 Value passThruSubVec =
91 rewriter.
create<vector::ExtractOp>(loc, passThruVec, thisIdx);
92 Value subGather = rewriter.
create<vector::GatherOp>(
93 loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
96 rewriter.
create<vector::InsertOp>(loc, subGather, result, thisIdx);
127 LogicalResult matchAndRewrite(vector::GatherOp op,
129 Value base = op.getBase();
136 auto sourceType = subview.getSource().getType();
139 if (sourceType.getRank() != 2)
143 auto layout = subview.getResult().getType().getLayout();
144 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
145 if (!stridedLayoutAttr)
149 if (stridedLayoutAttr.getStrides().size() != 1)
152 int64_t srcTrailingDim = sourceType.getShape().back();
157 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
162 Value collapsed = rewriter.
create<memref::CollapseShapeOp>(
163 op.getLoc(), subview.getSource(), reassoc);
167 IntegerAttr stride = rewriter.
getIndexAttr(srcTrailingDim);
168 VectorType vType = op.getIndexVec().getType();
173 rewriter.
create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
177 Value newGather = rewriter.
create<vector::GatherOp>(
178 op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
179 newIdxs, op.getMask(), op.getPassThru());
192 LogicalResult matchAndRewrite(vector::GatherOp op,
194 VectorType resultTy = op.getType();
195 if (resultTy.getRank() != 1)
198 if (resultTy.isScalable())
202 Type elemTy = resultTy.getElementType();
206 Value condMask = op.getMask();
207 Value base = op.getBase();
211 if (
auto memType = dyn_cast<MemRefType>(base.
getType())) {
212 if (
auto stridesAttr =
213 dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
214 if (stridesAttr.getStrides().back() != 1 &&
215 resultTy.getNumElements() != 1)
221 loc, op.getIndexVectorType().clone(rewriter.
getIndexType()),
223 auto baseOffsets = llvm::to_vector(op.getIndices());
224 Value lastBaseOffset = baseOffsets.back();
226 Value result = op.getPassThru();
229 for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
230 int64_t thisIdx[1] = {i};
232 rewriter.
create<vector::ExtractOp>(loc, condMask, thisIdx);
233 Value index = rewriter.
create<vector::ExtractOp>(loc, indexVec, thisIdx);
235 rewriter.
createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
239 if (isa<MemRefType>(base.
getType())) {
243 b.
create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
244 int64_t zeroIdx[1] = {0};
245 extracted = b.
create<vector::ExtractOp>(loc, load, zeroIdx);
247 extracted = b.
create<tensor::ExtractOp>(loc, base, baseOffsets);
251 b.
create<vector::InsertOp>(loc, extracted, result, thisIdx);
252 b.
create<scf::YieldOp>(loc, newResult);
255 b.
create<scf::YieldOp>(loc, result);
260 .
create<scf::IfOp>(loc, condition, loadBuilder,
278 patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...