36 #define DEBUG_TYPE "vector-broadcast-lowering"
63 VectorType resultTy = op.getType();
64 if (resultTy.getRank() < 2)
68 Value indexVec = op.getIndexVec();
69 Value maskVec = op.getMask();
70 Value passThruVec = op.getPassThru();
76 resultTy.getElementType());
78 for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
79 int64_t thisIdx[1] = {i};
82 rewriter.
create<vector::ExtractOp>(loc, indexVec, thisIdx);
84 rewriter.
create<vector::ExtractOp>(loc, maskVec, thisIdx);
85 Value passThruSubVec =
86 rewriter.
create<vector::ExtractOp>(loc, passThruVec, thisIdx);
87 Value subGather = rewriter.
create<vector::GatherOp>(
88 loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
91 rewriter.
create<vector::InsertOp>(loc, subGather, result, thisIdx);
123 Value base = op.getBase();
130 auto sourceType = subview.getSource().getType();
133 if (sourceType.getRank() != 2)
137 auto layout = subview.getResult().getType().getLayout();
138 auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
139 if (!stridedLayoutAttr)
143 if (stridedLayoutAttr.getStrides().size() != 1)
146 int64_t srcTrailingDim = sourceType.getShape().back();
151 if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
156 Value collapsed = rewriter.
create<memref::CollapseShapeOp>(
157 op.
getLoc(), subview.getSource(), reassoc);
161 IntegerAttr stride = rewriter.
getIndexAttr(srcTrailingDim);
162 VectorType vType = op.getIndexVec().getType();
167 rewriter.
create<arith::MulIOp>(op.
getLoc(), op.getIndexVec(), mulCst);
171 Value newGather = rewriter.
create<vector::GatherOp>(
173 newIdxs, op.getMask(), op.getPassThru());
188 VectorType resultTy = op.getType();
189 if (resultTy.getRank() != 1)
193 Type elemTy = resultTy.getElementType();
197 Value condMask = op.getMask();
198 Value base = op.getBase();
201 if (
auto memType = dyn_cast<MemRefType>(base.
getType())) {
202 if (
auto stridesAttr =
203 dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
204 if (stridesAttr.getStrides().back() != 1)
212 auto baseOffsets = llvm::to_vector(op.getIndices());
213 Value lastBaseOffset = baseOffsets.back();
215 Value result = op.getPassThru();
218 for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
219 int64_t thisIdx[1] = {i};
221 rewriter.
create<vector::ExtractOp>(loc, condMask, thisIdx);
222 Value index = rewriter.
create<vector::ExtractOp>(loc, indexVec, thisIdx);
224 rewriter.
createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
228 if (isa<MemRefType>(base.
getType())) {
232 b.
create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
233 int64_t zeroIdx[1] = {0};
234 extracted = b.
create<vector::ExtractOp>(loc, load, zeroIdx);
236 extracted = b.
create<tensor::ExtractOp>(loc, base, baseOffsets);
240 b.
create<vector::InsertOp>(loc, extracted, result, thisIdx);
241 b.
create<scf::YieldOp>(loc, newResult);
244 b.
create<scf::YieldOp>(loc, result);
249 .
create<scf::IfOp>(loc, condition, loadBuilder,
262 patterns.
add<FlattenGather, RemoveStrideFromGatherSource,
263 Gather1DToConditionalLoads>(patterns.
getContext(), benefit);
IntegerAttr getIndexAttr(int64_t value)
TypedAttr getZeroAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...