36 #define DEBUG_TYPE "vector-broadcast-lowering"
63 VectorType resultTy = op.getType();
64 if (resultTy.getRank() < 2)
68 Value indexVec = op.getIndexVec();
69 Value maskVec = op.getMask();
70 Value passThruVec = op.getPassThru();
76 resultTy.getElementType());
78 for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
79 int64_t thisIdx[1] = {i};
82 rewriter.
create<vector::ExtractOp>(loc, indexVec, thisIdx);
84 rewriter.
create<vector::ExtractOp>(loc, maskVec, thisIdx);
85 Value passThruSubVec =
86 rewriter.
create<vector::ExtractOp>(loc, passThruVec, thisIdx);
87 Value subGather = rewriter.
create<vector::GatherOp>(
88 loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
91 rewriter.
create<vector::InsertOp>(loc, subGather, result, thisIdx);
107 VectorType resultTy = op.getType();
108 if (resultTy.getRank() != 1)
112 Type elemTy = resultTy.getElementType();
116 Value condMask = op.getMask();
117 Value base = op.getBase();
121 auto baseOffsets = llvm::to_vector(op.getIndices());
122 Value lastBaseOffset = baseOffsets.back();
124 Value result = op.getPassThru();
127 for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
128 int64_t thisIdx[1] = {i};
130 rewriter.
create<vector::ExtractOp>(loc, condMask, thisIdx);
131 Value index = rewriter.
create<vector::ExtractOp>(loc, indexVec, thisIdx);
133 rewriter.
createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
137 if (isa<MemRefType>(base.
getType())) {
141 b.
create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
142 int64_t zeroIdx[1] = {0};
143 extracted = b.
create<vector::ExtractOp>(loc, load, zeroIdx);
145 extracted = b.
create<tensor::ExtractOp>(loc, base, baseOffsets);
149 b.
create<vector::InsertOp>(loc, extracted, result, thisIdx);
150 b.
create<scf::YieldOp>(loc, newResult);
153 b.
create<scf::YieldOp>(loc, result);
158 .
create<scf::IfOp>(loc, condition, loadBuilder,
171 patterns.
add<FlattenGather, Gather1DToConditionalLoads>(patterns.
getContext(),
TypedAttr getZeroAttr(Type type)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Location getLoc()
The source location the operation was defined or derived from.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...