MLIR 22.0.0git
VectorContractToPackedTypeDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16
18#include "mlir/IR/Dominance.h"
20
21#include "mlir/Pass/Pass.h"
23
24using namespace mlir;
25using namespace mlir::vector;
26using namespace mlir::x86vector;
27
28namespace {
29
30// Implements packed type outer product contraction as a sequence
31// of broadcast and packed dot-product operations.
32//
33// For example - for F32 type:
34// ```
35// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
36// ```
37// to
38// ```
39// vector.broadcast %lhs to <32xbf16>
40// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
41// ```
42struct VectorContractToPackedTypeDotProduct
43 : public OpRewritePattern<vector::ContractionOp> {
44 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
45
46 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
47 PatternRewriter &rewriter) const override {
48
49 if (contractOp.getKind() != vector::CombiningKind::ADD)
50 return rewriter.notifyMatchFailure(contractOp,
51 "Expects add combining kind.");
52
53 VectorType lhsTy = contractOp.getLhsType();
54 if (!lhsTy.getElementType().isBF16() &&
55 !lhsTy.getElementType().isSignlessInteger(8))
56 return rewriter.notifyMatchFailure(
57 contractOp, "Only BF16/Int8 lowering is supported.");
58
59 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
60 if (!isInVnniLayout(contractOp.getOperation(),
61 contractOp.getIndexingMapsArray(), blockingFactor))
62 return rewriter.notifyMatchFailure(contractOp,
63 "Input matrices not in VNNI format.");
64
65 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
66 llvm::SmallVector<int64_t> nonUnitDimLhs;
67 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
68 [](int64_t dim) { return dim != 1; });
69
70 VectorType rhsTy = contractOp.getRhsType();
71 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
72 llvm::SmallVector<int64_t> nonUnitDimRhs;
73 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
74 [](int64_t dim) { return dim != 1; });
75
76 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
77 return rewriter.notifyMatchFailure(contractOp,
78 "Excepts unit dimensions for either "
79 "LHS or RHS shape other than VNNI.");
80
81 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
82 return rewriter.notifyMatchFailure(
83 contractOp,
84 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
85
86 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
87 if (!accTy)
88 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
89
90 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
91 (lhsTy.getElementType().isSignlessInteger(8) &&
92 !accTy.getElementType().isSignlessInteger(32)))
93 return rewriter.notifyMatchFailure(contractOp,
94 "Only F32 for BF16 or Int32 for Int8 "
95 "accumulation type is supported.");
96
97 ArrayRef<int64_t> accShape = accTy.getShape();
98 llvm::SmallVector<int64_t> nonUnitDimAcc;
99 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
100 [](int64_t dim) { return dim != 1; });
101 if (nonUnitDimAcc.size() != 1)
102 return rewriter.notifyMatchFailure(
103 contractOp, "A or B should be a non-unit dim in acc.");
104
105 // Non-unit dimensions should match the vector length of BF16 or Int8
106 // dot-product.
107 unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
108 : nonUnitDimRhs.front();
109 if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
110 nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
111 return rewriter.notifyMatchFailure(
112 contractOp, "BF16 dot-product operation expects non-unit (LHR or "
113 "RHS) dim and acc dim of size 4/8/16.");
114
115 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
116 nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
117 return rewriter.notifyMatchFailure(
118 contractOp, "Int8 dot-product operation expects non-unit (LHR or "
119 "RHS) dim and acc dim of size 4/8.");
120
121 auto loc = contractOp.getLoc();
122 auto castAcc = vector::ShapeCastOp::create(
123 rewriter, loc,
124 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
125 contractOp.getAcc());
126
127 Value dp;
128
129 // Broadcast the unit-dimension LHS or RHS to match the vector length of the
130 // corresponding non-unit dimension on the other operand. For example,
131 // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>,
132 // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit
133 // dimension on the LHS), we broadcast the RHS instead.
134 if ((nonUnitDimRhs.size() - 1) > 0) {
135 auto castRhs = vector::ShapeCastOp::create(
136 rewriter, loc,
137 VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
138 rhsTy.getElementType()),
139 contractOp.getRhs());
140 auto castLhs = vector::ShapeCastOp::create(
141 rewriter, loc,
142 VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
143 contractOp.getLhs());
144 auto bitcastLhs = vector::BitCastOp::create(
145 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
146 castLhs);
147 auto broadcastLhs = vector::BroadcastOp::create(
148 rewriter, loc,
149 VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
150 bitcastLhs);
151 auto bitcastLhsPkType = vector::BitCastOp::create(
152 rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
153
154 if (lhsTy.getElementType().isBF16()) {
155 dp = x86vector::DotBF16Op::create(
156 rewriter, loc,
157 VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
158 castAcc, bitcastLhsPkType, castRhs);
159 }
160
161 if (lhsTy.getElementType().isSignlessInteger(8)) {
162 dp = x86vector::DotInt8Op::create(
163 rewriter, loc,
164 VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
165 castAcc, bitcastLhsPkType, castRhs);
166 }
167 } else {
168 auto castLhs = vector::ShapeCastOp::create(
169 rewriter, loc,
170 VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
171 lhsTy.getElementType()),
172 contractOp.getLhs());
173 auto castRhs = vector::ShapeCastOp::create(
174 rewriter, loc,
175 VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
176 contractOp.getRhs());
177 auto bitcastRhs = vector::BitCastOp::create(
178 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
179 castRhs);
180 auto broadcastRhs = vector::BroadcastOp::create(
181 rewriter, loc,
182 VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
183 bitcastRhs);
184 auto bitcastRhsPkType = vector::BitCastOp::create(
185 rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
186
187 if (lhsTy.getElementType().isBF16()) {
188 dp = x86vector::DotBF16Op::create(
189 rewriter, loc,
190 VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
191 castAcc, castLhs, bitcastRhsPkType);
192 }
193
194 if (lhsTy.getElementType().isSignlessInteger(8)) {
195 dp = x86vector::DotInt8Op::create(
196 rewriter, loc,
197 VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
198 castAcc, castLhs, bitcastRhsPkType);
199 }
200 }
201
202 if (!dp)
203 return failure();
204
205 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
206 rewriter.replaceOp(contractOp, castDp);
207 return success();
208 }
209};
210
211} // namespace
212
215 patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
216}
return success()
FloatType getF32Type()
Definition Builders.cpp:43
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...