// Debug channel name for this file. The patterns visible in this file all
// lower vector.gather, so the channel is named after gather lowering rather
// than broadcast lowering.
#define DEBUG_TYPE "vector-gather-lowering"
// Rewrite hook for vector::GatherOp. NOTE(review): this listing is partial;
// several original lines (remaining parameters, some declarations, closing
// braces) are not visible. The visible code slices the gather's index, mask
// and pass-through vectors at one position and builds a new gather of type
// `subTy` from those slices over the same base and offsets.
55 LogicalResult matchAndRewrite(vector::GatherOp op,
// Operands of the original gather that get sliced below.
57 Value indexVec = op.getIndices();
58 Value maskVec = op.getMask();
59 Value passThruVec = op.getPassThru();
// Tail of a parameter list whose introducer is not visible — presumably a
// callback invoked once per slice with the slice type and position; confirm
// against the full file.
62 VectorType subTy, int64_t index) {
// Single-element position shared by every extract below.
63 int64_t thisIdx[1] = {index};
// Slices of the index and mask vectors at `thisIdx` (the declarations that
// receive these values are on lines not shown).
66 vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
68 vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
// Matching slice of the pass-through vector.
69 Value passThruSubVec =
70 vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
// New gather with result type `subTy`, reusing the original base and offsets
// with the sliced index and mask; the remaining arguments are cut off here.
71 return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
72 op.getOffsets(), indexSubVec, maskSubVec,
// Rewrite hook for vector::GatherOp. NOTE(review): this listing is partial;
// the statements guarded by the `if`s and the definitions of `subview`,
// `reassoc` and `newIdxs` are on lines not shown. Visible steps: inspect the
// source type and result layout of `subview`, collapse the subview's source
// with memref::CollapseShapeOp, multiply the gather indices by the source's
// trailing dimension size, and build a new gather on the collapsed value.
103 LogicalResult matchAndRewrite(vector::GatherOp op,
// `subview` is presumably derived from this base — confirm in the full file.
105 Value base = op.getBase();
112 auto sourceType = subview.getSource().getType();
// Tests for a subview source whose rank is not 2.
115 if (sourceType.getRank() != 2)
// Layout of the subview result; only a StridedLayoutAttr is inspected further.
119 auto layout = subview.getResult().getType().getLayout();
120 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
121 if (!stridedLayoutAttr)
// Tests for a layout that does not have exactly one stride.
125 if (stridedLayoutAttr.getStrides().size() != 1)
// Size of the last dimension of the subview's source.
128 int64_t srcTrailingDim = sourceType.getShape().back();
// Tests whether the single stride differs from that trailing dimension size.
133 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
// Collapse the subview's source using `reassoc`.
138 Value collapsed = memref::CollapseShapeOp::create(
139 rewriter, op.getLoc(), subview.getSource(), reassoc);
// Index attribute carrying the trailing dimension size.
143 IntegerAttr stride = rewriter.
getIndexAttr(srcTrailingDim);
144 VectorType vType = op.getIndices().getType();
// Constant used as the multiplier below (its arguments are not shown;
// presumably built from `stride` and `vType`).
145 Value mulCst = arith::ConstantOp::create(
// Multiply the original gather indices by `mulCst`.
149 arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
// Rebuild the gather with the original result type, offsets, mask and
// pass-through, but reading from `collapsed` with the indices `newIdxs`.
153 Value newGather = vector::GatherOp::create(
154 rewriter, op.getLoc(), op.getResult().
getType(), collapsed,
155 op.getOffsets(), newIdxs, op.getMask(), op.getPassThru());
// Rewrite hook for vector::GatherOp. NOTE(review): this listing is partial;
// the statements guarded by the early `if`s, the definitions of `indexVec`,
// `condition`, `elemVecTy`, `load`, `newResult`, and the branch structure
// around the yields are on lines not shown. Visible steps: tests on the
// result type's rank and scalability, a stride test for memref bases, then a
// loop over the result elements that extracts the mask element and index,
// adds the index to the last base offset, reads one element from the base,
// and reassigns `result` from an scf::IfOp built on `condition`.
168 LogicalResult matchAndRewrite(vector::GatherOp op,
170 VectorType resultTy = op.getType();
// Tests for a result vector whose rank is not 1.
171 if (resultTy.getRank() != 1)
// Tests for a scalable result vector.
174 if (resultTy.isScalable())
178 Type elemTy = resultTy.getElementType();
182 Value condMask = op.getMask();
183 Value base = op.getBase();
// For a memref base with a strided layout: test whether the last stride is
// not 1 while the result does not have exactly one element.
187 if (
auto memType = dyn_cast<MemRefType>(base.
getType())) {
188 if (
auto stridesAttr =
189 dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
190 if (stridesAttr.getStrides().back() != 1 &&
191 resultTy.getNumElements() != 1)
// Arguments of a call not shown here; the type passed is the gather's index
// vector type re-created with index element type.
197 loc, op.getIndexVectorType().clone(rewriter.
getIndexType()),
// Local copy of the gather's offsets; its last entry is combined with each
// extracted index inside the loop.
199 auto baseOffsets = llvm::to_vector(op.getOffsets());
200 Value lastBaseOffset = baseOffsets.back();
// Start from the pass-through vector; `result` is reassigned at the bottom of
// each loop iteration.
202 Value result = op.getPassThru();
// One iteration per element of the result vector.
205 for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
206 int64_t thisIdx[1] = {i};
// Mask element and index for position `i`.
208 vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
209 Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
// Sum of the last base offset and this element's index (folded when possible).
211 rewriter.
createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
// Memref base: vector::LoadOp of type `elemVecTy` at `baseOffsets`, then
// extract element 0. The other visible path reads the element with
// tensor::ExtractOp.
215 if (isa<MemRefType>(base.
getType())) {
219 vector::LoadOp::create(b, loc, elemVecTy, base, baseOffsets);
220 int64_t zeroIdx[1] = {0};
221 extracted = vector::ExtractOp::create(b, loc, load, zeroIdx);
223 extracted = tensor::ExtractOp::create(b, loc, base, baseOffsets);
// Insert the element into `result` at position `i` and yield the updated
// vector.
227 vector::InsertOp::create(b, loc, extracted, result, thisIdx);
228 scf::YieldOp::create(b, loc, newResult);
// Yield `result` unchanged.
231 scf::YieldOp::create(b, loc, result);
// `result` is reassigned from the scf::IfOp built on `condition` (remaining
// arguments are cut off in this listing).
234 result = scf::IfOp::create(rewriter, loc, condition,
// Adds the RemoveStrideFromGatherSource and Gather1DToConditionalLoads
// patterns to `patterns`; the call's arguments are not visible in this listing.
253 patterns.add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
IntegerAttr getIndexAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...