42struct VectorContractToPackedTypeDotProduct
44 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
46 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
47 PatternRewriter &rewriter)
const override {
49 if (contractOp.getKind() != vector::CombiningKind::ADD)
51 "Expects add combining kind.");
53 VectorType lhsTy = contractOp.getLhsType();
54 if (!lhsTy.getElementType().isBF16() &&
55 !lhsTy.getElementType().isSignlessInteger(8))
57 contractOp,
"Only BF16/Int8 lowering is supported.");
59 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
61 contractOp.getIndexingMapsArray(), blockingFactor))
63 "Input matrices not in VNNI format.");
65 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
66 llvm::SmallVector<int64_t> nonUnitDimLhs;
67 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
68 [](int64_t dim) {
return dim != 1; });
70 VectorType rhsTy = contractOp.getRhsType();
71 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
72 llvm::SmallVector<int64_t> nonUnitDimRhs;
73 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
74 [](int64_t dim) {
return dim != 1; });
76 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
78 "Excepts unit dimensions for either "
79 "LHS or RHS shape other than VNNI.");
81 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
84 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
86 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
90 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
91 (lhsTy.getElementType().isSignlessInteger(8) &&
92 !accTy.getElementType().isSignlessInteger(32)))
94 "Only F32 for BF16 or Int32 for Int8 "
95 "accumulation type is supported.");
97 ArrayRef<int64_t> accShape = accTy.getShape();
98 llvm::SmallVector<int64_t> nonUnitDimAcc;
99 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
100 [](int64_t dim) {
return dim != 1; });
101 if (nonUnitDimAcc.size() != 1)
103 contractOp,
"A or B should be a non-unit dim in acc.");
107 unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
108 : nonUnitDimRhs.front();
109 if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
110 nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
112 contractOp,
"BF16 dot-product operation expects non-unit (LHR or "
113 "RHS) dim and acc dim of size 4/8/16.");
115 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
116 nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
118 contractOp,
"Int8 dot-product operation expects non-unit (LHR or "
119 "RHS) dim and acc dim of size 4/8.");
121 auto loc = contractOp.getLoc();
122 auto castAcc = vector::ShapeCastOp::create(
124 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
125 contractOp.getAcc());
134 if ((nonUnitDimRhs.size() - 1) > 0) {
135 auto castRhs = vector::ShapeCastOp::create(
137 VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
138 rhsTy.getElementType()),
139 contractOp.getRhs());
140 auto castLhs = vector::ShapeCastOp::create(
142 VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
143 contractOp.getLhs());
144 auto bitcastLhs = vector::BitCastOp::create(
147 auto broadcastLhs = vector::BroadcastOp::create(
149 VectorType::get({nonUnitDimRhs.front()}, rewriter.
getIntegerType(32)),
151 auto bitcastLhsPkType = vector::BitCastOp::create(
152 rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
154 if (lhsTy.getElementType().isBF16()) {
155 dp = x86vector::DotBF16Op::create(
157 VectorType::get(nonUnitDimRhs.front(), rewriter.
getF32Type()),
158 castAcc, bitcastLhsPkType, castRhs);
161 if (lhsTy.getElementType().isSignlessInteger(8)) {
162 dp = x86vector::DotInt8Op::create(
164 VectorType::get(nonUnitDimRhs.front(), rewriter.
getIntegerType(32)),
165 castAcc, bitcastLhsPkType, castRhs);
168 auto castLhs = vector::ShapeCastOp::create(
170 VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
171 lhsTy.getElementType()),
172 contractOp.getLhs());
173 auto castRhs = vector::ShapeCastOp::create(
175 VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
176 contractOp.getRhs());
177 auto bitcastRhs = vector::BitCastOp::create(
180 auto broadcastRhs = vector::BroadcastOp::create(
182 VectorType::get({nonUnitDimLhs.front()}, rewriter.
getIntegerType(32)),
184 auto bitcastRhsPkType = vector::BitCastOp::create(
185 rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
187 if (lhsTy.getElementType().isBF16()) {
188 dp = x86vector::DotBF16Op::create(
190 VectorType::get(nonUnitDimLhs.front(), rewriter.
getF32Type()),
191 castAcc, castLhs, bitcastRhsPkType);
194 if (lhsTy.getElementType().isSignlessInteger(8)) {
195 dp = x86vector::DotInt8Op::create(
197 VectorType::get(nonUnitDimLhs.front(), rewriter.
getIntegerType(32)),
198 castAcc, castLhs, bitcastRhsPkType);
205 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
IntegerType getIntegerType(unsigned width)
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...