// Target SIMD properties carried through the vectorization patterns.
struct VL {
  unsigned vectorLength;
  bool enableVLAVectorization;
  bool enableSIMDIndex32;
};
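// isInvariantValue: tests whether a value is loop invariant with respect to
// the given loop-body block.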
static bool isInvariantValue(Value val, Block *block) {
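// vectorType: builds the vector type used for vectorization, with
// vl.vectorLength lanes of the given element type (or of the element type of
// a memref value for the second overload), scalable when VLA vectorization
// is enabled.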
static VectorType vectorType(VL vl, Type etp) {
  return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
}
static VectorType vectorType(VL vl, Value mem) {
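// genVectorMask: constructs the loop mask. When constant bounds and step show
// that the iterations divide evenly, an all-true mask is broadcast; otherwise
// a vector.create_mask over a computed bound is emitted.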
  VectorType mtp = vectorType(vl, rewriter.getI1Type());
  IntegerAttr loInt, hiInt, stepInt;
  if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
    return vector::BroadcastOp::create(rewriter, loc, mtp, trueVal);

  return vector::CreateMaskOp::create(rewriter, loc, mtp, end);
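// genVectorInvariantValue: broadcasts a loop-invariant scalar into a vector
// of matching element type.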
  VectorType vtp = vectorType(vl, val.getType());
  return vector::BroadcastOp::create(rewriter, val.getLoc(), vtp, val);
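// genVectorLoad: emits a vector.gather when the innermost index is itself a
// vector (indirect access), and a vector.maskedload otherwise.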
  VectorType vtp = vectorType(vl, mem);
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    Value indexVec = idxs.back();
    return vector::GatherOp::create(rewriter, loc, vtp, mem, scalarArgs,
                                    indexVec, vmask, pass);
  }
  return vector::MaskedLoadOp::create(rewriter, loc, vtp, mem, idxs, vmask,
                                      pass);
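// genVectorStore: the store-side counterpart, emitting vector.scatter for an
// indirect index and vector.maskedstore for the contiguous case.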
  if (llvm::isa<VectorType>(idxs.back().getType())) {
    Value indexVec = idxs.back();
    vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask,
                              rhs);
    return;
  }
  vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs);
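// isVectorizableReduction: recognizes reductions over the loop's iteration
// argument and reports the vector combining kind (subtraction only reduces
// through its first operand).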
static bool isVectorizableReduction(Value red, Value iter,
                                    vector::CombiningKind &kind) {
  if (auto addf = red.getDefiningOp<arith::AddFOp>()) {
    kind = vector::CombiningKind::ADD;
    return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
  }
  if (auto addi = red.getDefiningOp<arith::AddIOp>()) {
    kind = vector::CombiningKind::ADD;
    return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
  }
  if (auto subf = red.getDefiningOp<arith::SubFOp>()) {
    kind = vector::CombiningKind::ADD;
    return subf->getOperand(0) == iter;
  }
  if (auto subi = red.getDefiningOp<arith::SubIOp>()) {
    kind = vector::CombiningKind::ADD;
    return subi->getOperand(0) == iter;
  }
  if (auto mulf = red.getDefiningOp<arith::MulFOp>()) {
    kind = vector::CombiningKind::MUL;
    return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
  }
  if (auto muli = red.getDefiningOp<arith::MulIOp>()) {
    kind = vector::CombiningKind::MUL;
    return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
  }
  if (auto andi = red.getDefiningOp<arith::AndIOp>()) {
    kind = vector::CombiningKind::AND;
    return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
  }
  if (auto ori = red.getDefiningOp<arith::OrIOp>()) {
    kind = vector::CombiningKind::OR;
    return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
  }
  if (auto xori = red.getDefiningOp<arith::XOrIOp>()) {
    kind = vector::CombiningKind::XOR;
    return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
  }
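// genVectorReducInit: produces the initial vector for the reduction by
// inserting (or broadcasting) the scalar start value according to the
// combining kind.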
  vector::CombiningKind kind;
  if (!isVectorizableReduction(red, iter, kind))
    llvm_unreachable("unknown reduction");
  switch (kind) {
  case vector::CombiningKind::ADD:
  case vector::CombiningKind::XOR:
    // Initialize reduction vector to: | 0 | .. | 0 | r |
    return vector::InsertOp::create(rewriter, loc, r,
                                    constantZero(rewriter, loc, vtp),
                                    constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::MUL:
    // Initialize reduction vector to: | 1 | .. | 1 | r |
    return vector::InsertOp::create(rewriter, loc, r,
                                    constantOne(rewriter, loc, vtp),
                                    constantIndex(rewriter, loc, 0));
  case vector::CombiningKind::AND:
  case vector::CombiningKind::OR:
    // Initialize reduction vector to: | r | .. | r | r |
    return vector::BroadcastOp::create(rewriter, loc, vtp, r);
  }
  llvm_unreachable("unknown reduction kind");
static bool vectorizeSubscripts(PatternRewriter &rewriter, scf::ForOp forOp,
                                VL vl, ValueRange subs, bool codegen,
                                Value vmask, SmallVectorImpl<Value> &idxs) {
  unsigned dim = subs.size();
  unsigned d = 0;
  for (auto sub : subs) {
    bool innermost = ++d == dim;
    if (isInvariantValue(sub, block)) {
    if (auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
      if (isInvariantArg(arg, block) == innermost)
    if (auto icast = cast.getDefiningOp<arith::IndexCastOp>())
      cast = icast->getOperand(0);
    else if (auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
      cast = ecast->getOperand(0);
    if (auto load = cast.getDefiningOp<memref::LoadOp>()) {
      Value vload =
          genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
      Type etp = llvm::cast<VectorType>(vload.getType()).getElementType();
      if (!llvm::isa<IndexType>(etp)) {
        if (etp.getIntOrFloatBitWidth() < 32)
          vload = arith::ExtUIOp::create(
              rewriter, loc, vectorType(vl, rewriter.getI32Type()), vload);
        else if (etp.getIntOrFloatBitWidth() < 64 && !vl.enableSIMDIndex32)
          vload = arith::ExtUIOp::create(
              rewriter, loc, vectorType(vl, rewriter.getI64Type()), vload);
      }
      idxs.push_back(vload);
    if (auto load = cast.getDefiningOp<arith::AddIOp>()) {
      Value inv = load.getOperand(0);
      Value idx = load.getOperand(1);
      if (!isInvariantValue(inv, block)) {
        inv = load.getOperand(1);
        idx = load.getOperand(0);
      }
      if (isInvariantValue(inv, block)) {
        if (auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
          if (isInvariantArg(arg, block) || !innermost)
            return false;
          idxs.push_back(
              arith::AddIOp::create(rewriter, forOp.getLoc(), inv, idx));
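// Code generation macros that map a scalar operation to its vectorized
// counterpart: plain unary, unary with an explicit result type, and binary.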
#define UNAOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    vexp = xxx::create(rewriter, loc, vx);                                     \

#define TYPEDUNAOP(xxx)                                                        \
  if (auto x = dyn_cast<xxx>(def)) {                                           \
    VectorType vtp = vectorType(vl, x.getType());                              \
    vexp = xxx::create(rewriter, loc, vtp, vx);                                \

#define BINOP(xxx)                                                             \
  if (isa<xxx>(def)) {                                                         \
    vexp = xxx::create(rewriter, loc, vx, vy);                                 \
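// vectorizeExpr: recursively vectorizes an expression. The induction variable
// becomes a broadcast plus vector.step, invariants are broadcast, loads go
// through the subscript analysis above, and arith ops are mapped through the
// macros; shift amounts must remain loop invariant.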
static bool vectorizeExpr(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          Value exp, bool codegen, Value vmask, Value &vexp) {
  if (!VectorType::isValidElementType(exp.getType()))
  if (auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
    if (arg == forOp.getInductionVar()) {
      // The induction variable vectorizes into a broadcast plus a step vector.
      VectorType vtp = vectorType(vl, arg.getType());
      Value veci = vector::BroadcastOp::create(rewriter, loc, vtp, arg);
      Value incr = vector::StepOp::create(rewriter, loc, vtp);
      vexp = arith::AddIOp::create(rewriter, loc, veci, incr);

    // An invariant block argument is simply broadcast.
    vexp = genVectorInvariantValue(rewriter, vl, exp);

  // A value defined outside the loop body is invariant and broadcast as well.
  vexp = genVectorInvariantValue(rewriter, vl, exp);
  if (auto load = dyn_cast<memref::LoadOp>(def)) {
    auto subs = load.getIndices();
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
      vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
  // Unary case.
  if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                    vx)) {

  // Binary case.
  if (vectorizeExpr(rewriter, forOp, vl, def->getOperand(0), codegen, vmask,
                    vx) &&
      vectorizeExpr(rewriter, forOp, vl, def->getOperand(1), codegen, vmask,
                    vy)) {
  if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
      isa<arith::ShRSIOp>(def)) {
    // Shift amounts must be loop invariant to vectorize.
    if (!isInvariantValue(shiftFactor, block))

  BINOP(arith::DivSIOp)
  BINOP(arith::DivUIOp)
  BINOP(arith::ShRUIOp)
  BINOP(arith::ShRSIOp)
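// vectorizeStmt: vectorizes an scf.for whose body ends in either a reduction
// yield or a store. It scales the step for VLA, builds the loop mask, creates
// a new vectorized scf.for for reductions, and emits the masked stores or the
// final vector.reduction.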
static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
                          bool codegen) {
  scf::YieldOp yield = cast<scf::YieldOp>(block.getTerminator());
  auto &last = *++block.rbegin();
  if (vl.enableVLAVectorization) {
    // For scalable vectors, the step is scaled by the runtime vector length.
    Value vscale =
        vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
    step = arith::MulIOp::create(rewriter, loc, vscale, step);
  }
  if (!yield.getResults().empty()) {
    // A reduction loop: seed the vector iteration argument and build the new
    // scf.for around it.
    Value init = forOp.getInitArgs()[0];
    VectorType vtp = vectorType(vl, init.getType());
    Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
                                     forOp.getRegionIterArg(0), init, vtp);
    forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
                                  forOp.getUpperBound(), step, vinit,
                                  nullptr, forOp.getUnsignedCmp());
  vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
                        forOp.getLowerBound(), forOp.getUpperBound(), step);
  if (!yield.getResults().empty()) {
    // Vectorize a reduction: only single-result yields are handled.
    if (yield->getNumOperands() != 1)
      return false;
    Value red = yield->getOperand(0);
    Value iter = forOp.getRegionIterArg(0);
    vector::CombiningKind kind;
    Value vrhs;
    if (isVectorizableReduction(red, iter, kind) &&
        vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
      // Mask the vectorized reduction value, yield it inside the new loop,
      // and reduce the final vector back to a scalar after the loop.
      Value partial = forOpNew.getResult(0);
      Value vpass = genVectorInvariantValue(rewriter, vl, iter);
      Value vred = arith::SelectOp::create(rewriter, loc, vmask, vrhs, vpass);
      scf::YieldOp::create(rewriter, loc, vred);
      Value vres = vector::ReductionOp::create(rewriter, loc, kind, partial);
      rewriter.replaceAllUsesWith(forOp.getInductionVar(),
                                  forOpNew.getInductionVar());
      rewriter.replaceAllUsesWith(forOp.getRegionIterArg(0),
                                  forOpNew.getRegionIterArg(0));
  } else if (auto store = dyn_cast<memref::StoreOp>(last)) {
    // Vectorize a store of a vectorizable right-hand side.
    auto subs = store.getIndices();
    Value rhs = store.getValue();
    if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
        vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
      genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
  assert(!codegen && "cannot call codegen when analysis failed");
  ForOpRewriter(MLIRContext *context, unsigned vectorLength,
                bool enableVLAVectorization, bool enableSIMDIndex32)
      : OpRewritePattern(context),
        vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
  LogicalResult matchAndRewrite(scf::ForOp op,
                                PatternRewriter &rewriter) const override {
    // Only single-block loops with unit step are candidates.
    if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) ||
    if (vectorizeStmt(rewriter, op, vl, /*codegen=*/false) &&
        vectorizeStmt(rewriter, op, vl, /*codegen=*/true))
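// cleanReducChain: when the value being re-expanded comes from a
// vector.reduction of a vectorized loop's result, the expansion is replaced
// directly by that loop's vector result.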
  if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
    rewriter.replaceOp(op, redOp.getVector());
struct ReducChainBroadcastRewriter
    : public OpRewritePattern<vector::BroadcastOp> {
  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    return cleanReducChain(rewriter, op, op.getSource());
struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
  LogicalResult matchAndRewrite(vector::InsertOp op,
                                PatternRewriter &rewriter) const override {
    return cleanReducChain(rewriter, op, op.getValueToStore());
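// populateSparseVectorizationPatterns: public entry point that registers the
// loop rewriter and the reduction-chain cleanup patterns.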
void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
                                               unsigned vectorLength,
                                               bool enableVLAVectorization,
                                               bool enableSIMDIndex32) {
  assert(vectorLength > 0);
  patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                              enableVLAVectorization, enableSIMDIndex32);
  patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
      patterns.getContext());
}
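// A minimal usage sketch (not part of the original file): one typical way to
// run these patterns over an operation with the greedy rewrite driver. The
// header locations, vector length, and flag values below are illustrative
// assumptions.
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void vectorizeSparseLoops(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // 16 lanes, no scalable (VLA) vectorization, default index width.
  mlir::populateSparseVectorizationPatterns(patterns, /*vectorLength=*/16,
                                            /*enableVLAVectorization=*/false,
                                            /*enableSIMDIndex32=*/false);
  // Greedily apply the patterns; ignore whether a fixpoint was reached.
  (void)mlir::applyPatternsGreedily(root, std::move(patterns));
}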