29#define DEBUG_TYPE "vector-broadcast-lowering"
56 LogicalResult matchAndRewrite(vector::GatherOp op,
58 Value indexVec = op.getIndices();
59 Value maskVec = op.getMask();
60 Value passThruVec = op.getPassThru();
67 vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
69 vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
70 Value passThruSubVec =
71 vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
72 return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
73 op.getOffsets(), indexSubVec, maskSubVec,
74 passThruSubVec, op.getAlignmentAttr());
104 LogicalResult matchAndRewrite(vector::GatherOp op,
106 Value base = op.getBase();
113 auto sourceType = subview.getSource().getType();
116 if (sourceType.getRank() != 2)
120 auto layout = subview.getResult().getType().getLayout();
121 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
122 if (!stridedLayoutAttr)
126 if (stridedLayoutAttr.getStrides().size() != 1)
129 int64_t srcTrailingDim = sourceType.getShape().back();
134 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
139 Value collapsed = memref::CollapseShapeOp::create(
140 rewriter, op.getLoc(), subview.getSource(), reassoc);
144 IntegerAttr stride = rewriter.
getIndexAttr(srcTrailingDim);
145 VectorType vType = op.getIndices().getType();
146 Value mulCst = arith::ConstantOp::create(
150 arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
154 Value newGather = vector::GatherOp::create(
155 rewriter, op.getLoc(), op.getResult().
getType(), collapsed,
156 op.getOffsets(), newIdxs, op.getMask(), op.getPassThru(),
157 op.getAlignmentAttr());
177 LogicalResult matchAndRewrite(vector::GatherOp op,
179 VectorType resultTy = op.getType();
180 if (resultTy.getRank() != 1)
183 if (resultTy.isScalable())
187 Type elemTy = resultTy.getElementType();
189 VectorType elemVecTy = VectorType::get({1}, elemTy);
191 Value condMask = op.getMask();
192 Value base = op.getBase();
196 bool useDelinearization =
false;
197 if (
auto memType = dyn_cast<MemRefType>(base.
getType())) {
200 if (
auto stridesAttr =
201 dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
202 if (stridesAttr.getStrides().back() != 1 &&
203 resultTy.getNumElements() != 1)
205 op,
"most minor memref dim must have unit stride");
208 if (memType.getRank() > 1)
209 useDelinearization =
true;
215 auto loadOffsets = llvm::to_vector(op.getOffsets());
216 Value lastLoadOffset = loadOffsets.back();
221 Value linearizedOffsets;
222 if (useDelinearization) {
224 linearizedOffsets = affine::AffineLinearizeIndexOp::create(
225 rewriter, loc, loadOffsets, baseShape,
false);
230 IntegerAttr alignmentAttr = op.getAlignmentAttr();
233 for (
int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
236 vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
237 Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
239 if (useDelinearization) {
247 auto delinOp = affine::AffineDelinearizeIndexOp::create(
248 rewriter, loc, flatIdx, baseShape,
true);
249 for (
int64_t d = 0, rank = loadOffsets.size(); d < rank; ++d)
250 loadOffsets[d] = delinOp.getResult(d);
258 if (isa<MemRefType>(base.
getType())) {
262 vector::LoadOp::create(
b, loc, elemVecTy, base, loadOffsets,
263 nontemporalAttr, alignmentAttr);
265 extracted = vector::ExtractOp::create(
b, loc,
load, zeroIdx);
267 extracted = tensor::ExtractOp::create(
b, loc, base, loadOffsets);
271 vector::InsertOp::create(
b, loc, extracted,
result, thisIdx);
272 scf::YieldOp::create(
b, loc, newResult);
275 scf::YieldOp::create(
b, loc,
result);
278 result = scf::IfOp::create(rewriter, loc, condition,
297 patterns.
add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...