MLIR 22.0.0git
VectorContractToPackedTypeDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15
17#include "mlir/IR/Dominance.h"
19
20#include "mlir/Pass/Pass.h"
22
23using namespace mlir;
24using namespace mlir::vector;
25using namespace mlir::x86vector;
26
27namespace {
28
29static FailureOr<SmallVector<mlir::utils::IteratorType>>
31 if (!map.isProjectedPermutation())
32 return failure();
34 map.getNumDims(), mlir::utils::IteratorType::reduction);
35 for (auto expr : map.getResults())
36 if (auto dim = dyn_cast<AffineDimExpr>(expr))
37 iterators[dim.getPosition()] = mlir::utils::IteratorType::parallel;
38 return iterators;
39}
40
41// Returns true if the operation is in VNNI layout.
42// Optionally, the check can be constrained to a specific VNNI blocking factor.
43static bool isInVnniLayout(Operation *op, ArrayRef<AffineMap> indexingMaps,
44 std::optional<unsigned> blockingFactor) {
45 // Narrow down type operations - VNNI only applies to contractions.
46 FailureOr<linalg::ContractionDimensions> dims =
47 linalg::inferContractionDims(indexingMaps);
48 if (failed(dims))
49 return false;
50
51 auto matA = op->getOperand(0);
52 auto matB = op->getOperand(1);
53 auto typeA = dyn_cast<ShapedType>(matA.getType());
54 auto typeB = dyn_cast<ShapedType>(matB.getType());
55 unsigned rankA = typeA.getRank();
56 unsigned rankB = typeB.getRank();
57 // VNNI format requires at least 1 parallel and 2 reduction dimensions.
58 if (rankA < 3 || rankB < 3)
59 return false;
60
61 // At least two reduction dimensions are expected:
62 // one for the VNNI factor and one for the K dimension
63 if (dims->k.size() < 2)
64 return false;
65
66 // Validate affine maps - VNNI computation should be defined by the two
67 // innermost reduction iterators.
68 // The input matrix dimensions layout must match the following:
69 // - matrix A - [...][K/vnniFactor][vnniFactor]
70 // - matrix B - [...][K/vnniFactor][N][vnniFactor]
71 auto maybeIters = inferIteratorsFromOutMap(indexingMaps[2]);
72 if (failed(maybeIters))
73 return false;
74 SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
75 AffineMap mapA = indexingMaps[0];
76 AffineMap mapB = indexingMaps[1];
77
78 auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
79 auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
80 if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
81 iteratorTypes[vnniDimA.getPosition()] !=
82 mlir::utils::IteratorType::reduction)
83 return false;
84 auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
85 auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
86 if (!redDimA || !redDimB || redDimA != redDimB ||
87 iteratorTypes[redDimA.getPosition()] !=
88 mlir::utils::IteratorType::reduction)
89 return false;
90 auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
91 if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
92 mlir::utils::IteratorType::parallel)
93 return false;
94
95 // VNNI factor must be:
96 // - the innermost inputs' dimension
97 // - statically known
98 // - multiple of 2 or equal to the specified factor
99 auto vnniDimSize = typeB.getShape().back();
100 if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
101 vnniDimSize % 2 != 0)
102 return false;
103 if (typeA.getShape().back() != vnniDimSize)
104 return false;
105 if (blockingFactor && vnniDimSize != *blockingFactor)
106 return false;
107
108 // The split reduction dimension size should also match.
109 if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
110 return false;
111
112 return true;
113}
114
115// Implements packed type outer product contraction as a sequence
116// of broadcast and packed dot-product operations.
117//
118// For example - for F32 type:
119// ```
120// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
121// ```
122// to
123// ```
124// vector.broadcast %lhs to <32xbf16>
125// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
126// ```
127struct VectorContractToPackedTypeDotProduct
128 : public OpRewritePattern<vector::ContractionOp> {
129 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
130
131 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
132 PatternRewriter &rewriter) const override {
133
134 if (contractOp.getKind() != vector::CombiningKind::ADD)
135 return rewriter.notifyMatchFailure(contractOp,
136 "Expects add combining kind.");
137
138 VectorType lhsTy = contractOp.getLhsType();
139 if (!lhsTy.getElementType().isBF16() &&
140 !lhsTy.getElementType().isSignlessInteger(8))
141 return rewriter.notifyMatchFailure(
142 contractOp, "Only BF16/Int8 lowering is supported.");
143
144 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
145 if (!isInVnniLayout(contractOp.getOperation(),
146 contractOp.getIndexingMapsArray(), blockingFactor))
147 return rewriter.notifyMatchFailure(contractOp,
148 "Input matrices not in VNNI format.");
149
150 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
151 llvm::SmallVector<int64_t> nonUnitDimLhs;
152 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
153 [](int64_t dim) { return dim != 1; });
154
155 VectorType rhsTy = contractOp.getRhsType();
156 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
157 llvm::SmallVector<int64_t> nonUnitDimRhs;
158 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
159 [](int64_t dim) { return dim != 1; });
160
161 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
162 return rewriter.notifyMatchFailure(contractOp,
163 "Excepts unit dimensions for either "
164 "LHS or RHS shape other than VNNI.");
165
166 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
167 return rewriter.notifyMatchFailure(
168 contractOp,
169 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
170
171 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
172 if (!accTy)
173 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
174
175 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
176 (lhsTy.getElementType().isSignlessInteger(8) &&
177 !accTy.getElementType().isSignlessInteger(32)))
178 return rewriter.notifyMatchFailure(contractOp,
179 "Only F32 for BF16 or Int32 for Int8 "
180 "accumulation type is supported.");
181
182 ArrayRef<int64_t> accShape = accTy.getShape();
183 llvm::SmallVector<int64_t> nonUnitDimAcc;
184 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
185 [](int64_t dim) { return dim != 1; });
186 if (nonUnitDimAcc.size() != 1)
187 return rewriter.notifyMatchFailure(
188 contractOp, "A or B should be a non-unit dim in acc.");
189
190 // Non-unit dimensions should match the vector length of BF16 or Int8
191 // dot-product.
192 unsigned int nonUnitDim = nonUnitDimLhs.size() == 2 ? nonUnitDimLhs.front()
193 : nonUnitDimRhs.front();
194 if (lhsTy.getElementType().isBF16() && nonUnitDim != 4 && nonUnitDim != 8 &&
195 nonUnitDim != 16 && nonUnitDimAcc.front() == nonUnitDim)
196 return rewriter.notifyMatchFailure(
197 contractOp, "BF16 dot-product operation expects non-unit (LHR or "
198 "RHS) dim and acc dim of size 4/8/16.");
199
200 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDim != 4 &&
201 nonUnitDim != 8 && nonUnitDimAcc.front() == nonUnitDim)
202 return rewriter.notifyMatchFailure(
203 contractOp, "Int8 dot-product operation expects non-unit (LHR or "
204 "RHS) dim and acc dim of size 4/8.");
205
206 auto loc = contractOp.getLoc();
207 auto castAcc = vector::ShapeCastOp::create(
208 rewriter, loc,
209 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
210 contractOp.getAcc());
211
212 Value dp;
213
214 // Broadcast the unit-dimension LHS or RHS to match the vector length of the
215 // corresponding non-unit dimension on the other operand. For example,
216 // if LHS has type vector<1x1x2xbf16> and RHS has type vector<1x16x2xbf16>,
217 // we broadcast the LHS to vector<16x2xbf16>. In the opposite case (non-unit
218 // dimension on the LHS), we broadcast the RHS instead.
219 if ((nonUnitDimRhs.size() - 1) > 0) {
220 auto castRhs = vector::ShapeCastOp::create(
221 rewriter, loc,
222 VectorType::get(nonUnitDimRhs.front() * nonUnitDimRhs.back(),
223 rhsTy.getElementType()),
224 contractOp.getRhs());
225 auto castLhs = vector::ShapeCastOp::create(
226 rewriter, loc,
227 VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
228 contractOp.getLhs());
229 auto bitcastLhs = vector::BitCastOp::create(
230 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
231 castLhs);
232 auto broadcastLhs = vector::BroadcastOp::create(
233 rewriter, loc,
234 VectorType::get({nonUnitDimRhs.front()}, rewriter.getIntegerType(32)),
235 bitcastLhs);
236 auto bitcastLhsPkType = vector::BitCastOp::create(
237 rewriter, loc, castRhs.getResult().getType(), broadcastLhs);
238
239 if (lhsTy.getElementType().isBF16()) {
240 dp = x86vector::DotBF16Op::create(
241 rewriter, loc,
242 VectorType::get(nonUnitDimRhs.front(), rewriter.getF32Type()),
243 castAcc, bitcastLhsPkType, castRhs);
244 }
245
246 if (lhsTy.getElementType().isSignlessInteger(8)) {
247 dp = x86vector::DotInt8Op::create(
248 rewriter, loc,
249 VectorType::get(nonUnitDimRhs.front(), rewriter.getIntegerType(32)),
250 castAcc, bitcastLhsPkType, castRhs);
251 }
252 } else {
253 auto castLhs = vector::ShapeCastOp::create(
254 rewriter, loc,
255 VectorType::get(nonUnitDimLhs.front() * nonUnitDimLhs.back(),
256 lhsTy.getElementType()),
257 contractOp.getLhs());
258 auto castRhs = vector::ShapeCastOp::create(
259 rewriter, loc,
260 VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
261 contractOp.getRhs());
262 auto bitcastRhs = vector::BitCastOp::create(
263 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
264 castRhs);
265 auto broadcastRhs = vector::BroadcastOp::create(
266 rewriter, loc,
267 VectorType::get({nonUnitDimLhs.front()}, rewriter.getIntegerType(32)),
268 bitcastRhs);
269 auto bitcastRhsPkType = vector::BitCastOp::create(
270 rewriter, loc, castLhs.getResult().getType(), broadcastRhs);
271
272 if (lhsTy.getElementType().isBF16()) {
273 dp = x86vector::DotBF16Op::create(
274 rewriter, loc,
275 VectorType::get(nonUnitDimLhs.front(), rewriter.getF32Type()),
276 castAcc, castLhs, bitcastRhsPkType);
277 }
278
279 if (lhsTy.getElementType().isSignlessInteger(8)) {
280 dp = x86vector::DotInt8Op::create(
281 rewriter, loc,
282 VectorType::get(nonUnitDimLhs.front(), rewriter.getIntegerType(32)),
283 castAcc, castLhs, bitcastRhsPkType);
284 }
285 }
286
287 if (!dp)
288 return failure();
289
290 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
291 rewriter.replaceOp(contractOp, castDp);
292 return success();
293 }
294};
295
296} // namespace
297
300 patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
301}
return success()
static FailureOr< SmallVector< utils::IteratorType > > inferIteratorsFromOutMap(AffineMap map)
Infer the iterator types from the init affine map.
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
AffineExpr getResult(unsigned idx) const
FloatType getF32Type()
Definition Builders.cpp:43
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...