29#define DEBUG_TYPE "vector-broadcast-lowering"
56 LogicalResult matchAndRewrite(vector::GatherOp op,
58 Value indexVec = op.getIndices();
59 Value maskVec = op.getMask();
60 Value passThruVec = op.getPassThru();
67 vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
69 vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
70 Value passThruSubVec =
71 vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
72 return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
73 op.getOffsets(), indexSubVec, maskSubVec,
74 passThruSubVec, op.getAlignmentAttr());
110 LogicalResult matchAndRewrite(vector::GatherOp op,
112 Value base = op.getBase();
119 auto sourceType = subview.getSource().getType();
122 if (sourceType.getRank() != 2)
126 auto layout = subview.getResult().getType().getLayout();
127 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
128 if (!stridedLayoutAttr)
132 if (stridedLayoutAttr.getStrides().size() != 1)
135 int64_t srcTrailingDim = sourceType.getShape().back();
140 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
148 int64_t subviewOffset = stridedLayoutAttr.getOffset();
149 if (ShapedType::isDynamic(subviewOffset))
154 Value collapsed = memref::CollapseShapeOp::create(
155 rewriter, op.getLoc(), subview.getSource(), reassoc);
163 IntegerAttr stride = rewriter.
getIndexAttr(srcTrailingDim);
164 VectorType vType = op.getIndices().getType();
165 Value mulCst = arith::ConstantOp::create(
168 arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst);
181 newOffsets.back() = rewriter.
createOrFold<arith::MulIOp>(
182 op.getLoc(), newOffsets.back(), strideVal);
183 Value subviewOffsetValue =
185 newOffsets.back() = rewriter.
createOrFold<arith::AddIOp>(
186 op.getLoc(), newOffsets.back(), subviewOffsetValue);
190 Value newGather = vector::GatherOp::create(
191 rewriter, op.getLoc(), op.getResult().
getType(), collapsed, newOffsets,
192 newIdxs, op.getMask(), op.getPassThru(), op.getAlignmentAttr());
212 LogicalResult matchAndRewrite(vector::GatherOp op,
214 VectorType resultTy = op.getType();
215 if (resultTy.getRank() != 1)
218 if (resultTy.isScalable())
222 Type elemTy = resultTy.getElementType();
224 VectorType elemVecTy = VectorType::get({1}, elemTy);
226 Value condMask = op.getMask();
227 Value base = op.getBase();
231 bool useDelinearization =
false;
232 if (
auto memType = dyn_cast<MemRefType>(base.
getType())) {
235 if (
auto stridesAttr =
236 dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
237 if (stridesAttr.getStrides().back() != 1 &&
238 resultTy.getNumElements() != 1)
240 op,
"most minor memref dim must have unit stride");
243 if (memType.getRank() > 1)
244 useDelinearization =
true;
250 auto loadOffsets = llvm::to_vector(op.getOffsets());
251 Value lastLoadOffset = loadOffsets.back();
256 Value linearizedOffsets;
257 if (useDelinearization) {
259 linearizedOffsets = affine::AffineLinearizeIndexOp::create(
260 rewriter, loc, loadOffsets, baseShape,
false);
265 IntegerAttr alignmentAttr = op.getAlignmentAttr();
268 for (
int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
271 vector::ExtractOp::create(rewriter, loc, condMask, thisIdx);
272 Value index = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
274 if (useDelinearization) {
282 auto delinOp = affine::AffineDelinearizeIndexOp::create(
283 rewriter, loc, flatIdx, baseShape,
true);
284 for (
int64_t d = 0, rank = loadOffsets.size(); d < rank; ++d)
285 loadOffsets[d] = delinOp.getResult(d);
293 if (isa<MemRefType>(base.
getType())) {
297 vector::LoadOp::create(
b, loc, elemVecTy, base, loadOffsets,
298 nontemporalAttr, alignmentAttr);
300 extracted = vector::ExtractOp::create(
b, loc,
load, zeroIdx);
302 extracted = tensor::ExtractOp::create(
b, loc, base, loadOffsets);
306 vector::InsertOp::create(
b, loc, extracted,
result, thisIdx);
307 scf::YieldOp::create(
b, loc, newResult);
310 scf::YieldOp::create(
b, loc,
result);
313 result = scf::IfOp::create(rewriter, loc, condition,
332 patterns.
add<RemoveStrideFromGatherSource, Gather1DToConditionalLoads>(
Special case of IntegerAttr to represent boolean integers, i.e., signless i1 integers.
IntegerAttr getIndexAttr(int64_t value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorGatherToConditionalLoadPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
LogicalResult unrollVectorOp(Operation *op, PatternRewriter &rewriter, UnrollVectorOpFn unrollFn)
Include the generated interface declarations.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...