42 unsigned vectorLength;
43 bool enableVLAVectorization;
44 bool enableSIMDIndex32;
48 static bool isInvariantValue(
Value val,
Block *block) {
58 static VectorType vectorType(VL vl,
Type etp) {
59 return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization);
63 static VectorType vectorType(VL vl,
Value mem) {
70 VectorType mtp = vectorType(vl, rewriter.
getI1Type());
75 IntegerAttr loInt, hiInt, stepInt;
79 if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0) {
81 return rewriter.
create<vector::BroadcastOp>(loc, mtp, trueVal);
95 return rewriter.
create<vector::CreateMaskOp>(loc, mtp, end);
102 VectorType vtp = vectorType(vl, val.
getType());
103 return rewriter.
create<vector::BroadcastOp>(val.
getLoc(), vtp, val);
112 VectorType vtp = vectorType(vl, mem);
114 if (llvm::isa<VectorType>(idxs.back().getType())) {
116 Value indexVec = idxs.back();
118 return rewriter.
create<vector::GatherOp>(loc, vtp, mem, scalarArgs,
119 indexVec, vmask, pass);
121 return rewriter.
create<vector::MaskedLoadOp>(loc, vtp, mem, idxs, vmask,
131 if (llvm::isa<VectorType>(idxs.back().getType())) {
133 Value indexVec = idxs.back();
135 rewriter.
create<vector::ScatterOp>(loc, mem, scalarArgs, indexVec, vmask,
139 rewriter.
create<vector::MaskedStoreOp>(loc, mem, idxs, vmask, rhs);
144 static bool isVectorizableReduction(
Value red,
Value iter,
145 vector::CombiningKind &kind) {
147 kind = vector::CombiningKind::ADD;
148 return addf->getOperand(0) == iter || addf->getOperand(1) == iter;
151 kind = vector::CombiningKind::ADD;
152 return addi->getOperand(0) == iter || addi->getOperand(1) == iter;
155 kind = vector::CombiningKind::ADD;
156 return subf->getOperand(0) == iter;
159 kind = vector::CombiningKind::ADD;
160 return subi->getOperand(0) == iter;
163 kind = vector::CombiningKind::MUL;
164 return mulf->getOperand(0) == iter || mulf->getOperand(1) == iter;
167 kind = vector::CombiningKind::MUL;
168 return muli->getOperand(0) == iter || muli->getOperand(1) == iter;
171 kind = vector::CombiningKind::AND;
172 return andi->getOperand(0) == iter || andi->getOperand(1) == iter;
175 kind = vector::CombiningKind::OR;
176 return ori->getOperand(0) == iter || ori->getOperand(1) == iter;
179 kind = vector::CombiningKind::XOR;
180 return xori->getOperand(0) == iter || xori->getOperand(1) == iter;
193 vector::CombiningKind kind;
194 if (!isVectorizableReduction(red, iter, kind))
195 llvm_unreachable(
"unknown reduction");
197 case vector::CombiningKind::ADD:
198 case vector::CombiningKind::XOR:
200 return rewriter.
create<vector::InsertElementOp>(
203 case vector::CombiningKind::MUL:
205 return rewriter.
create<vector::InsertElementOp>(
208 case vector::CombiningKind::AND:
209 case vector::CombiningKind::OR:
211 return rewriter.
create<vector::BroadcastOp>(loc, vtp, r);
215 llvm_unreachable(
"unknown reduction kind");
237 static bool vectorizeSubscripts(
PatternRewriter &rewriter, scf::ForOp forOp,
241 unsigned dim = subs.size();
243 for (
auto sub : subs) {
244 bool innermost = ++d == dim;
250 if (isInvariantValue(sub, block)) {
262 if (
auto arg = llvm::dyn_cast<BlockArgument>(sub)) {
263 if (isInvariantArg(arg, block) == innermost)
272 if (
auto icast = cast.getDefiningOp<arith::IndexCastOp>())
273 cast = icast->getOperand(0);
274 else if (
auto ecast = cast.getDefiningOp<arith::ExtUIOp>())
275 cast = ecast->getOperand(0);
292 if (
auto load = cast.getDefiningOp<memref::LoadOp>()) {
299 genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs2, vmask);
300 Type etp = llvm::cast<VectorType>(vload.
getType()).getElementType();
301 if (!llvm::isa<IndexType>(etp)) {
303 vload = rewriter.
create<arith::ExtUIOp>(
304 loc, vectorType(vl, rewriter.
getI32Type()), vload);
306 vload = rewriter.
create<arith::ExtUIOp>(
307 loc, vectorType(vl, rewriter.
getI64Type()), vload);
309 idxs.push_back(vload);
316 if (
auto load = cast.getDefiningOp<arith::AddIOp>()) {
317 Value inv = load.getOperand(0);
318 Value idx = load.getOperand(1);
319 if (isInvariantValue(inv, block)) {
320 if (
auto arg = llvm::dyn_cast<BlockArgument>(idx)) {
321 if (isInvariantArg(arg, block) || !innermost)
325 rewriter.
create<arith::AddIOp>(forOp.getLoc(), inv, idx));
336 if (isa<xxx>(def)) { \
338 vexp = rewriter.create<xxx>(loc, vx); \
342 #define TYPEDUNAOP(xxx) \
343 if (auto x = dyn_cast<xxx>(def)) { \
345 VectorType vtp = vectorType(vl, x.getType()); \
346 vexp = rewriter.create<xxx>(loc, vtp, vx); \
352 if (isa<xxx>(def)) { \
354 vexp = rewriter.create<xxx>(loc, vx, vy); \
364 static bool vectorizeExpr(
PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
368 if (!VectorType::isValidElementType(exp.
getType()))
371 if (
auto arg = llvm::dyn_cast<BlockArgument>(exp)) {
372 if (arg == forOp.getInductionVar()) {
376 VectorType vtp = vectorType(vl, arg.
getType());
377 Value veci = rewriter.
create<vector::BroadcastOp>(loc, vtp, arg);
379 if (vl.enableVLAVectorization) {
381 Value stepv = rewriter.
create<LLVM::StepVectorOp>(loc, stepvty);
382 incr = rewriter.
create<arith::IndexCastOp>(loc, vtp, stepv);
385 for (
unsigned i = 0, l = vl.vectorLength; i < l; i++)
386 integers.push_back(APInt(64, i));
388 incr = rewriter.
create<arith::ConstantOp>(loc, vtp, values);
390 vexp = rewriter.
create<arith::AddIOp>(loc, veci, incr);
398 vexp = genVectorInvariantValue(rewriter, vl, exp);
406 vexp = genVectorInvariantValue(rewriter, vl, exp);
415 if (
auto load = dyn_cast<memref::LoadOp>(def)) {
416 auto subs = load.getIndices();
418 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs)) {
420 vexp = genVectorLoad(rewriter, loc, vl, load.getMemRef(), idxs, vmask);
434 if (vectorizeExpr(rewriter, forOp, vl, def->
getOperand(0), codegen, vmask,
461 if (vectorizeExpr(rewriter, forOp, vl, def->
getOperand(0), codegen, vmask,
463 vectorizeExpr(rewriter, forOp, vl, def->
getOperand(1), codegen, vmask,
469 if (isa<arith::ShLIOp>(def) || isa<arith::ShRUIOp>(def) ||
470 isa<arith::ShRSIOp>(def)) {
472 if (!isInvariantValue(shiftFactor, block))
479 BINOP(arith::DivSIOp)
480 BINOP(arith::DivUIOp)
489 BINOP(arith::ShRUIOp)
490 BINOP(arith::ShRSIOp)
505 static bool vectorizeStmt(
PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
516 scf::YieldOp yield = cast<scf::YieldOp>(block.
getTerminator());
517 auto &last = *++block.
rbegin();
531 if (vl.enableVLAVectorization) {
534 step = rewriter.
create<arith::MulIOp>(loc, vscale, step);
536 if (!yield.getResults().empty()) {
537 Value init = forOp.getInitArgs()[0];
538 VectorType vtp = vectorType(vl, init.
getType());
539 Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
540 forOp.getRegionIterArg(0), init, vtp);
541 forOpNew = rewriter.
create<scf::ForOp>(
542 loc, forOp.getLowerBound(), forOp.getUpperBound(), step, vinit);
551 vmask = genVectorMask(rewriter, loc, vl, forOp.getInductionVar(),
552 forOp.getLowerBound(), forOp.getUpperBound(), step);
557 if (!yield.getResults().empty()) {
559 if (yield->getNumOperands() != 1)
561 Value red = yield->getOperand(0);
562 Value iter = forOp.getRegionIterArg(0);
563 vector::CombiningKind kind;
565 if (isVectorizableReduction(red, iter, kind) &&
566 vectorizeExpr(rewriter, forOp, vl, red, codegen, vmask, vrhs)) {
568 Value partial = forOpNew.getResult(0);
569 Value vpass = genVectorInvariantValue(rewriter, vl, iter);
570 Value vred = rewriter.
create<arith::SelectOp>(loc, vmask, vrhs, vpass);
571 rewriter.
create<scf::YieldOp>(loc, vred);
573 Value vres = rewriter.
create<vector::ReductionOp>(loc, kind, partial);
579 forOpNew.getInductionVar());
581 forOpNew.getRegionIterArg(0));
586 }
else if (
auto store = dyn_cast<memref::StoreOp>(last)) {
588 auto subs = store.getIndices();
590 Value rhs = store.getValue();
592 if (vectorizeSubscripts(rewriter, forOp, vl, subs, codegen, vmask, idxs) &&
593 vectorizeExpr(rewriter, forOp, vl, rhs, codegen, vmask, vrhs)) {
595 genVectorStore(rewriter, loc, store.getMemRef(), idxs, vmask, vrhs);
602 assert(!codegen &&
"cannot call codegen when analysis failed");
611 ForOpRewriter(
MLIRContext *context,
unsigned vectorLength,
612 bool enableVLAVectorization,
bool enableSIMDIndex32)
614 enableSIMDIndex32} {}
625 if (vectorizeStmt(rewriter, op, vl,
false) &&
626 vectorizeStmt(rewriter, op, vl,
true))
640 template <
typename VectorOp>
647 Value inp = op.getSource();
649 if (
auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
651 rewriter.
replaceOp(op, redOp.getVector());
668 unsigned vectorLength,
669 bool enableVLAVectorization,
670 bool enableSIMDIndex32) {
671 assert(vectorLength > 0);
672 patterns.
add<ForOpRewriter>(patterns.
getContext(), vectorLength,
673 enableVLAVectorization, enableSIMDIndex32);
674 patterns.
add<ReducChainRewriter<vector::InsertElementOp>,
675 ReducChainRewriter<vector::BroadcastOp>>(patterns.
getContext());
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
This class represents an argument of a Block.
Block * getOwner() const
Returns the block that owns this argument.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
OpListType & getOperations()
reverse_iterator rbegin()
AffineExpr getAffineSymbolExpr(unsigned position)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to appear after it.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
unsigned getNumOperands()
Block * getBlock()
Returns the operation block that contains this operation.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched and rewritten.
bool hasOneBlock()
Return true if this region has exactly one block.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName()
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Value constantZero(OpBuilder &builder, Location loc, Type tp)
Generates a 0-valued constant of the given type.
Value constantOne(OpBuilder &builder, Location loc, Type tp)
Generates a 1-valued constant of the given type.
Value constantI1(OpBuilder &builder, Location loc, bool b)
Generates a constant of i1 type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
void populateSparseVectorizationPatterns(RewritePatternSet &patterns, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32)
Populates the given patterns list with vectorization rules.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.