42 unsigned vectorLength;
43 bool enableVLAVectorization;
44 bool enableSIMDIndex32;
48 static bool isInvariantValue(
Value val,
Block *block) {
58 static VectorType vectorType(VL vl,
Type etp) {
59 return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
63 static VectorType vectorType(VL vl,
Value mem) {
70 VectorType mtp = vectorType(vl, rewriter.
getI1Type());
75 IntegerAttr loInt, hiInt, stepInt;
79 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
81 return rewriter.
create<vector::BroadcastOp>(loc, mtp, trueVal);
95 return rewriter.
create<vector::CreateMaskOp>(loc, mtp, end);
102 VectorType vtp = vectorType(vl, val.
getType());
103 return rewriter.
create<vector::BroadcastOp>(val.
getLoc(), vtp, val);
112 VectorType vtp = vectorType(vl, mem);
114 if (llvm::isa<VectorType>(idxs.back().getType())) {
116 Value indexVec = idxs.back();
118 return rewriter.
create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
119 indexVec, vmask, pass);
121 return rewriter.
create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
131 if (llvm::isa<VectorType>(idxs.back().getType())) {
133 Value indexVec = idxs.back();
135 rewriter.
create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
139 rewriter.
create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
144 static bool isVectorizableReduction(
Value red,
Value iter,
145 vector::CombiningKind &kind) {
147 kind = vector::CombiningKind::ADD;
148 return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
151 kind = vector::CombiningKind::ADD;
152 return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
155 kind = vector::CombiningKind::ADD;
156 return subf->getOperand(0) == iter;
159 kind = vector::CombiningKind::ADD;
160 return subi->getOperand(0) == iter;
163 kind = vector::CombiningKind::MUL;
164 return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
167 kind = vector::CombiningKind::MUL;
168 return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
171 kind = vector::CombiningKind::AND;
172 return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
175 kind = vector::CombiningKind::OR;
176 return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
179 kind = vector::CombiningKind::XOR;
180 return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
193 vector::CombiningKind kind;
194 if (!isVectorizableReduction(red, iter, kind))
195 llvm_unreachable(
"unknown reduction");
197 case vector::CombiningKind::ADD:
198 case vector::CombiningKind::XOR:
200 return rewriter.
create<vector::InsertElementOp>(
203 case vector::CombiningKind::MUL:
205 return rewriter.
create<vector::InsertElementOp>(
208 case vector::CombiningKind::AND:
209 case vector::CombiningKind::OR:
211 return rewriter.
create<vector::BroadcastOp>(loc, vtp, r);
215 llvm_unreachable(
"unknown reduction kind");
237 static bool vectorizeSubscripts(
PatternRewriter &rewriter, scf::ForOp forOp,
241 unsigned dim = subs.size();
243 for (
auto sub : subs) {
244 bool innermost = ++d == dim;
250 if (isInvariantValue(sub, block)) {
262 if (
auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
263 if (isInvariantArg(arg, block) == innermost)
272 if (
auto icast = cast.getDefiningOp<arith::IndexCastOp>())
273 cast = icast->getOperand(0);
274 else if (
auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
275 cast = ecast->getOperand(0);
292 if (
auto load = cast.getDefiningOp<memref::LoadOp>()) {
299 genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
300 Type etp = llvm::cast<VectorType>(vload.
getType()).getElementType();
301 if (!llvm::isa<IndexType>(etp)) {
303 vload = rewriter.
create<arith::ExtUIOp>(
304 loc, vectorType(vl, rewriter.
getI32Type()), vload);
306 vload = rewriter.
create<arith::ExtUIOp>(
307 loc, vectorType(vl, rewriter.
getI64Type()), vload);
309 idxs.push_back(vload);
316 if (
auto load = cast.getDefiningOp<arith::AddIOp>()) {
317 Value inv = load.getOperand(0);
318 Value idx = load.getOperand(1);
320 if (!isInvariantValue(inv, block)) {
322 idx = load.getOperand(0);
325 if (isInvariantValue(inv, block)) {
326 if (
auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
327 if (isInvariantArg(arg, block) || !innermost)
331 rewriter.
create<arith::AddIOp>(forOp.getLoc(), inv, idx));
342 if (isa<xxx>(def)) { \
344 vexp = rewriter.create<xxx>(loc, vx); \
348 #define TYPEDUNAOP(xxx) \
349 if (auto x = dyn_cast<xxx>(def)) { \
351 VectorType vtp = vectorType(vl, x.getType()); \
352 vexp = rewriter.create<xxx>(loc, vtp, vx); \
358 if (isa<xxx>(def)) { \
360 vexp = rewriter.create<xxx>(loc, vx, vy); \
370 static bool vectorizeExpr(
PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
374 if (!VectorType::isValidElementType(exp.
getType()))
377 if (
auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
378 if (arg == forOp.getInductionVar()) {
382 VectorType vtp = vectorType(vl, arg.
getType());
383 Value veci = rewriter.
create<vector::BroadcastOp>(loc, vtp, arg);
385 if (vl.enableVLAVectorization) {
387 Value stepv = rewriter.
create<LLVM::StepVectorOp>(loc, stepvty);
388 incr = rewriter.
create<arith::IndexCastOp>(loc, vtp, stepv);
391 for (
unsigned i = 0, l = vl.vectorLength; i < l; i++)
392 integers.push_back(APInt(64, i));
394 incr = rewriter.
create<arith::ConstantOp>(loc, vtp, values);
396 vexp = rewriter.
create<arith::AddIOp>(loc, veci, incr);
404 vexp = genVectorInvariantValue(rewriter, vl, exp);
412 vexp = genVectorInvariantValue(rewriter, vl, exp);
421 if (
auto load = dyn_cast<memref::LoadOp>(def)) {
422 auto subs = load.getIndices();
424 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
426 vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
440 if (vectorizeExpr(rewriter, forOp, vl, def->
getOperand(0), codegen, vmask,
467 if (vectorizeExpr(rewriter, forOp, vl, def->
getOperand(0), codegen, vmask,
469 vectorizeExpr(rewriter, forOp, vl, def->
getOperand(1), codegen, vmask,
475 if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
476 isa<arith::ShRSIOp>(def)) {
478 if (!isInvariantValue(shiftFactor, block))
485 BINOP(arith::DivSIOp)
486 BINOP(arith::DivUIOp)
495 BINOP(arith::ShRUIOp)
496 BINOP(arith::ShRSIOp)
511 static bool vectorizeStmt(
PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
522 scf::YieldOp yield = cast<scf::YieldOp>(block.
getTerminator());
523 auto &last = *++block.
rbegin();
537 if (vl.enableVLAVectorization) {
540 step = rewriter.
create<arith::MulIOp>(loc, vscale, step);
542 if (!yield.getResults().empty()) {
543 Value init = forOp.getInitArgs()[0];
544 VectorType vtp = vectorType(vl, init.
getType());
545 Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
546 forOp.getRegionIterArg(0), init, vtp);
547 forOpNew = rewriter.
create<scf::ForOp>(
548 loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
557 vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
558 forOp.getLowerBound(), forOp.getUpperBound(), step);
563 if (!yield.getResults().empty()) {
565 if (yield->getNumOperands() != 1)
567 Value red = yield->getOperand(0);
568 Value iter = forOp.getRegionIterArg(0);
569 vector::CombiningKind kind;
571 if (isVectorizableReduction(red, iter, kind) &&
572 vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
574 Value partial = forOpNew.getResult(0);
575 Value vpass = genVectorInvariantValue(rewriter, vl, iter);
576 Value vred = rewriter.
create<arith::SelectOp>(loc, vmask, vrhs, vpass);
577 rewriter.
create<scf::YieldOp>(loc, vred);
579 Value vres = rewriter.
create<vector::ReductionOp>(loc, kind, partial);
585 forOpNew.getInductionVar());
587 forOpNew.getRegionIterArg(0));
592 }
else if (
auto store = dyn_cast<memref::StoreOp>(last)) {
594 auto subs = store.getIndices();
596 Value rhs = store.getValue();
598 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
599 vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
601 genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
608 assert(!codegen &&
"cannot call codegen when analysis failed");
617 ForOpRewriter(
MLIRContext *context,
unsigned vectorLength,
618 bool enableVLAVectorization,
bool enableSIMDIndex32)
620 enableSIMDIndex32} {}
631 if (vectorizeStmt(rewriter, op, vl,
false) &&
632 vectorizeStmt(rewriter, op, vl,
true))
646 template <
typename VectorOp>
653 Value inp = op.getSource();
655 if (
auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
657 rewriter.
replaceOp(op, redOp.getVector());
674 unsigned vectorLength,
675 bool enableVLAVectorization,
676 bool enableSIMDIndex32) {
677 assert(vectorLength > 0);
678 patterns.
add<ForOpRewriter>(patterns.
getContext(), vectorLength,
679 enableVLAVectorization, enableSIMDIndex32);
680 patterns.
add<ReducChainRewriter<vector::InsertElementOp>,
681 ReducChainRewriter<vector::BroadcastOp>>(patterns.
getContext());
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
reverse_iterator rbegin()
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
bool hasOneBlock()
Return true if this region has exactly one block.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...