43 unsigned vectorLength;
44 bool enableVLAVectorization;
45 bool enableSIMDIndex32;
49 static bool isInvariantValue(
Value val,
Block *block) {
59 static VectorType vectorType(VL vl,
Type etp) {
60 return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
64 static VectorType vectorType(VL vl,
Value mem) {
71 VectorType mtp = vectorType(vl, rewriter.
getI1Type());
76 IntegerAttr loInt, hiInt, stepInt;
80 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
82 return rewriter.
create<vector::BroadcastOp>(loc, mtp, trueVal);
96 return rewriter.
create<vector::CreateMaskOp>(loc, mtp, end);
103 VectorType vtp = vectorType(vl, val.
getType());
104 return rewriter.
create<vector::BroadcastOp>(val.
getLoc(), vtp, val);
113 VectorType vtp = vectorType(vl, mem);
115 if (llvm::isa<VectorType>(idxs.back().getType())) {
117 Value indexVec = idxs.back();
119 return rewriter.
create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
120 indexVec, vmask, pass);
122 return rewriter.
create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
132 if (llvm::isa<VectorType>(idxs.back().getType())) {
134 Value indexVec = idxs.back();
136 rewriter.
create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
140 rewriter.
create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
145 static bool isVectorizableReduction(
Value red,
Value iter,
146 vector::CombiningKind &kind) {
148 kind = vector::CombiningKind::ADD;
149 return addf->
getOperand(0) == iter || addf->getOperand(1) == iter;
152 kind = vector::CombiningKind::ADD;
153 return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
156 kind = vector::CombiningKind::ADD;
157 return subf->getOperand(0) == iter;
160 kind = vector::CombiningKind::ADD;
161 return subi->getOperand(0) == iter;
164 kind = vector::CombiningKind::MUL;
165 return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
168 kind = vector::CombiningKind::MUL;
169 return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
172 kind = vector::CombiningKind::AND;
173 return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
176 kind = vector::CombiningKind::OR;
177 return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
180 kind = vector::CombiningKind::XOR;
181 return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
194 vector::CombiningKind kind;
195 if (!isVectorizableReduction(red, iter, kind))
196 llvm_unreachable(
"unknown reduction");
198 case vector::CombiningKind::ADD:
199 case vector::CombiningKind::XOR:
201 return rewriter.
create<vector::InsertElementOp>(
204 case vector::CombiningKind::MUL:
206 return rewriter.
create<vector::InsertElementOp>(
209 case vector::CombiningKind::AND:
210 case vector::CombiningKind::OR:
212 return rewriter.
create<vector::BroadcastOp>(loc, vtp, r);
216 llvm_unreachable(
"unknown reduction kind");
238 static bool vectorizeSubscripts(
PatternRewriter &rewriter, scf::ForOp forOp,
242 unsigned dim = subs.size();
244 for (
auto sub : subs) {
245 bool innermost = ++d == dim;
251 if (isInvariantValue(sub, block)) {
263 if (
auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
264 if (isInvariantArg(arg, block) == innermost)
273 if (
auto icast = cast.getDefiningOp<arith::IndexCastOp>())
274 cast = icast->getOperand(0);
275 else if (
auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
276 cast = ecast->getOperand(0);
293 if (
auto load = cast.getDefiningOp<memref::LoadOp>()) {
300 genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
301 Type etp = llvm::cast<VectorType>(vload.
getType()).getElementType();
302 if (!llvm::isa<IndexType>(etp)) {
304 vload = rewriter.
create<arith::ExtUIOp>(
305 loc, vectorType(vl, rewriter.
getI32Type()), vload);
307 vload = rewriter.
create<arith::ExtUIOp>(
308 loc, vectorType(vl, rewriter.
getI64Type()), vload);
310 idxs.push_back(vload);
317 if (
auto load = cast.getDefiningOp<arith::AddIOp>()) {
318 Value inv = load.getOperand(0);
319 Value idx = load.getOperand(1);
321 if (!isInvariantValue(inv, block)) {
323 idx = load.getOperand(0);
326 if (isInvariantValue(inv, block)) {
327 if (
auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
328 if (isInvariantArg(arg, block) || !innermost)
332 rewriter.
create<arith::AddIOp>(forOp.getLoc(), inv, idx));
343 if (isa<xxx>(def)) { \
345 vexp = rewriter.create<xxx>(loc, vx); \
349 #define TYPEDUNAOP(xxx) \
350 if (auto x = dyn_cast<xxx>(def)) { \
352 VectorType vtp = vectorType(vl, x.getType()); \
353 vexp = rewriter.create<xxx>(loc, vtp, vx); \
359 if (isa<xxx>(def)) { \
361 vexp = rewriter.create<xxx>(loc, vx, vy); \
371 static bool vectorizeExpr(
PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
375 if (!VectorType::isValidElementType(exp.
getType()))
378 if (
auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
379 if (arg == forOp.getInductionVar()) {
383 VectorType vtp = vectorType(vl, arg.
getType());
384 Value veci = rewriter.
create<vector::BroadcastOp>(loc, vtp, arg);
385 Value incr = rewriter.
create<vector::StepOp>(loc, vtp);
386 vexp = rewriter.
create<arith::AddIOp>(loc, veci, incr);
394 vexp = genVectorInvariantValue(rewriter, vl, exp);
402 vexp = genVectorInvariantValue(rewriter, vl, exp);
411 if (
auto load = dyn_cast<memref::LoadOp>(def)) {
412 auto subs = load.getIndices();
414 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
416 vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
430 if (vectorizeExpr(rewriter, forOp, vl, def->
getOperand(0), codegen, vmask,
457 if (vectorizeExpr(rewriter, forOp, vl, def->
getOperand(0), codegen, vmask,
459 vectorizeExpr(rewriter, forOp, vl, def->
getOperand(1), codegen, vmask,
465 if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
466 isa<arith::ShRSIOp>(def)) {
468 if (!isInvariantValue(shiftFactor, block))
475 BINOP(arith::DivSIOp)
476 BINOP(arith::DivUIOp)
485 BINOP(arith::ShRUIOp)
486 BINOP(arith::ShRSIOp)
501 static bool vectorizeStmt(
PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
512 scf::YieldOp yield = cast<scf::YieldOp>(block.
getTerminator());
513 auto &last = *++block.
rbegin();
527 if (vl.enableVLAVectorization) {
530 step = rewriter.
create<arith::MulIOp>(loc, vscale, step);
532 if (!yield.getResults().empty()) {
533 Value init = forOp.getInitArgs()[0];
534 VectorType vtp = vectorType(vl, init.
getType());
535 Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
536 forOp.getRegionIterArg(0), init, vtp);
537 forOpNew = rewriter.
create<scf::ForOp>(
538 loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
547 vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
548 forOp.getLowerBound(), forOp.getUpperBound(), step);
553 if (!yield.getResults().empty()) {
555 if (yield->getNumOperands() != 1)
557 Value red = yield->getOperand(0);
558 Value iter = forOp.getRegionIterArg(0);
559 vector::CombiningKind kind;
561 if (isVectorizableReduction(red, iter, kind) &&
562 vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
564 Value partial = forOpNew.getResult(0);
565 Value vpass = genVectorInvariantValue(rewriter, vl, iter);
566 Value vred = rewriter.
create<arith::SelectOp>(loc, vmask, vrhs, vpass);
567 rewriter.
create<scf::YieldOp>(loc, vred);
569 Value vres = rewriter.
create<vector::ReductionOp>(loc, kind, partial);
575 forOpNew.getInductionVar());
577 forOpNew.getRegionIterArg(0));
582 }
else if (
auto store = dyn_cast<memref::StoreOp>(last)) {
584 auto subs = store.getIndices();
586 Value rhs = store.getValue();
588 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
589 vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
591 genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
598 assert(!codegen &&
"cannot call codegen when analysis failed");
607 ForOpRewriter(
MLIRContext *context,
unsigned vectorLength,
608 bool enableVLAVectorization,
bool enableSIMDIndex32)
610 enableSIMDIndex32} {}
612 LogicalResult matchAndRewrite(scf::ForOp op,
621 if (vectorizeStmt(rewriter, op, vl,
false) &&
622 vectorizeStmt(rewriter, op, vl,
true))
636 template <
typename VectorOp>
641 LogicalResult matchAndRewrite(VectorOp op,
643 Value inp = op.getSource();
645 if (
auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
647 rewriter.
replaceOp(op, redOp.getVector());
664 unsigned vectorLength,
665 bool enableVLAVectorization,
666 bool enableSIMDIndex32) {
667 assert(vectorLength > 0);
669 patterns.
add<ForOpRewriter>(patterns.
getContext(), vectorLength,
670 enableVLAVectorization, enableSIMDIndex32);
671 patterns.
add<ReducChainRewriter<vector::InsertElementOp>,
672 ReducChainRewriter<vector::BroadcastOp>>(patterns.
getContext());
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
reverse_iterator rbegin()
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of `from` and replace them with `to`.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if `ofr` is a constant integer equal to `value`.
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...