MLIR 23.0.0git
VectorContractToPackedTypeDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17
19#include "mlir/IR/Dominance.h"
21
22#include "mlir/Pass/Pass.h"
24
25using namespace mlir;
26using namespace mlir::vector;
27using namespace mlir::x86vector;
28
29namespace {
30
31// Returns true if the A or B matrix vector is packed (shuffled) to
32// VNNI layout, already.
33static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
34 if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
35 if (isa<vector::ShuffleOp>(defOp))
36 return true;
37
38 if (isa<vector::ShapeCastOp>(defOp)) {
39 Operation *defOpShpCst = defOp->getOperand(0).getDefiningOp();
40 if (isa<vector::ShuffleOp>(defOpShpCst))
41 return true;
42 }
43 }
44
45 return false;
46}
47
48static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
49 mlir::Operation *targetContract,
50 mlir::PatternRewriter &rewriter) {
51 for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
52
53 mlir::Operation *user = use.getOwner();
54 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
55 mlir::isa<mlir::scf::ForOp>(user)) {
56 use.set(newVal);
57 }
58 }
59}
60
61// Function to convert the flat layout A or B matrix vector<32xbf16>
62// into VNNI packed layout using the vpunpack operations
63static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
64 mlir::Operation *opA,
65 mlir::Operation *opB,
66 mlir::vector::ContractionOp contractA,
67 mlir::vector::ContractionOp contractB,
68 int64_t nonUnitDimAcc,
69 mlir::VectorType Ty) {
70 mlir::Operation *insertAfter = opA->isBeforeInBlock(opB) ? opB : opA;
71
72 rewriter.setInsertionPointAfter(insertAfter);
73 mlir::Location loc = insertAfter->getLoc();
74
75 auto elemTy = Ty.getElementType();
76 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
77
78 auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
79 opA->getResult(0));
80 auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
81 opB->getResult(0));
82
83 static constexpr int64_t maskLo[] = {
84 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
85 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
86 static constexpr int64_t maskHi[] = {
87 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
88 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
89
90 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
91 castB, maskLo);
92 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
93 castB, maskHi);
94
95 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
96 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
97
98 rewriteUses(opA->getResult(0), newA.getResult(), contractA, rewriter);
99 rewriteUses(opB->getResult(0), newB.getResult(), contractB, rewriter);
100}
101
102// Implements packed type outer product contraction as a sequence
103// of broadcast and packed dot-product operations.
104//
105// For example - for bf16 type (VNNI):
106// ```
107// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
108// ```
109// to
110// ```
111// vector.broadcast %lhs to <32xbf16>
112// x86vector.avx512.dot vector<32xbf16> -> vector<16xf32>
113// ```
114//
115// For example - for bf16 type (Flat layout):
116// ```
117// %1 = vector.load -> <2x16xbf16>
118// %2 = vector.load -> <2x16xbf16>
119// vector.contract <1x2xbf16>, %1 into <1x16xf32>
120// vector.contract <1x2xbf16>, %2 into <1x16xf32>
121// ```
122// to
123// ```
124// %1 = vector.load -> <2x16xbf16>
125// %2 = vector.load -> <2x16xbf16>
126// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
127// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
128// vector.broadcast %lhs to <32xbf16>
129// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
130// vector.broadcast %lhs to <32xbf16>
131// x86vector.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
132// ```
133struct VectorContractToPackedTypeDotProduct
134 : public OpRewritePattern<vector::ContractionOp> {
135 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
136
137 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
138 PatternRewriter &rewriter) const override {
139
140 if (contractOp.getKind() != vector::CombiningKind::ADD)
141 return rewriter.notifyMatchFailure(contractOp,
142 "Expects add combining kind.");
143
144 VectorType lhsTy = contractOp.getLhsType();
145 if (!lhsTy.getElementType().isBF16() &&
146 !lhsTy.getElementType().isSignlessInteger(8))
147 return rewriter.notifyMatchFailure(
148 contractOp, "Only BF16/Int8 lowering is supported.");
149
150 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
151 bool isVnni =
152 isInVnniLayout(contractOp.getOperation(),
153 contractOp.getIndexingMapsArray(), blockingFactor);
154
155 if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
156 return failure();
157
158 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
159 if (!accTy)
160 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
161
162 ArrayRef<int64_t> accShape = accTy.getShape();
163 llvm::SmallVector<int64_t> nonUnitDimAcc;
164 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
165 [](int64_t dim) { return dim != 1; });
166 if (nonUnitDimAcc.size() != 1)
167 return rewriter.notifyMatchFailure(
168 contractOp, "A or B should be a non-unit dim in acc.");
169
170 int64_t nonUnitDimValue = nonUnitDimAcc.front();
171 // Non-unit dimensions should match the vector length of BF16 or Int8
172 // dot-product.
173 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
174 nonUnitDimValue != 8 && nonUnitDimValue != 16)
175 return rewriter.notifyMatchFailure(
176 contractOp, "BF16 dot-product operation expects non-unit (LHR or "
177 "RHS) dim and acc dim of size 4/8/16.");
178
179 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
180 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
181 nonUnitDimAcc.front() == nonUnitDimValue)
182 return rewriter.notifyMatchFailure(
183 contractOp, "Int8 dot-product operation expects non-unit (LHR or "
184 "RHS) dim and acc dim of size 4/8/16.");
185
186 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
187 llvm::SmallVector<int64_t> nonUnitDimLhs;
188 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
189 [](int64_t dim) { return dim != 1; });
190
191 VectorType rhsTy = contractOp.getRhsType();
192 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
193 llvm::SmallVector<int64_t> nonUnitDimRhs;
194 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
195 [](int64_t dim) { return dim != 1; });
196
197 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
198 return rewriter.notifyMatchFailure(contractOp,
199 "Excepts unit dimensions for either "
200 "LHS or RHS shape.");
201
202 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
203 return rewriter.notifyMatchFailure(
204 contractOp,
205 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
206
207 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
208 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
209 : nonUnitDimRhs.front();
210
211 if (!isVnni && (extraFlatDim != blockingFactor))
212 return rewriter.notifyMatchFailure(
213 contractOp, "The K or reduction dim for flat layout should be 2.");
214
215 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
216 (lhsTy.getElementType().isSignlessInteger(8) &&
217 !accTy.getElementType().isSignlessInteger(32)))
218 return rewriter.notifyMatchFailure(contractOp,
219 "Only F32 for BF16 or Int32 for Int8 "
220 "accumulation type is supported.");
221
222 Value unitDimOperand =
223 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
224 Value nonUnitDimOperand =
225 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
226
227 // If the A or B matrix vector of the contact operation is not packed, then
228 // find it's pair contract operation and pack (shuffle) them to VNNI packed.
229 if (!isVnni) {
230 vector::ContractionOp pairContractOp;
231 Operation *nextOp = contractOp;
232 while ((nextOp = nextOp->getNextNode())) {
233 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
234
235 if (!contOp)
236 continue;
237
238 if (validatePairVectorContract(contractOp, contOp,
239 rhsHasMultipleNonUnitDims,
240 nonUnitDimValue)) {
241 pairContractOp = contOp;
242 break;
243 }
244 }
245
246 // If the accumulators are shuffled we get nullptr else the
247 // transfer_read or load operations.
248 Operation *accRead =
249 traceToVectorReadLikeParentOperation(contractOp.getAcc());
250
251 if (!pairContractOp &&
252 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
253 return rewriter.notifyMatchFailure(contractOp,
254 "Could not find a contract pair");
255
256 // Validate and shuffle the accumulator
257 if (accRead) {
258 // Trace back to the load or transfer_read operations of the contract
259 // accumulators.
260 Operation *accReadOp0 =
261 traceToVectorReadLikeParentOperation(contractOp.getAcc());
262 Operation *accReadOp1 =
263 traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
264
265 // Iterate down to find the users of contact operations until it is
266 // store or transfer_write.
267 Operation *resultWriteOp0 =
268 traceToVectorWriteLikeUserOperation(contractOp.getResult());
269 Operation *resultWriteOp1 =
270 traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
271
272 if (!accReadOp0 || !accReadOp1)
273 return rewriter.notifyMatchFailure(
274 contractOp,
275 "Operands doesn't have load or transfer_read as it's parent op");
276
277 if (!resultWriteOp0 || !resultWriteOp1)
278 return rewriter.notifyMatchFailure(
279 contractOp,
280 "The use of contract operations are neither vector.store "
281 "or transfer_write or has multiple users.");
282
283 if (contractOp->getBlock() == accReadOp1->getBlock() &&
284 contractOp->isBeforeInBlock(accReadOp1))
285 return rewriter.notifyMatchFailure(
286 contractOp,
287 "The load/read operation of pair contract operation is "
288 "after the contractOp");
289
290 if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
291 resultWriteOp0->isBeforeInBlock(pairContractOp))
292 return rewriter.notifyMatchFailure(
293 contractOp, "The store/write operation of contract operation is "
294 "before the pair contract operation");
295 // Shuffle the accumulators of the contract operations.
296 LogicalResult readShuffle =
297 shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
298 pairContractOp, nonUnitDimValue, accTy);
299
300 if (failed(readShuffle))
301 return rewriter.notifyMatchFailure(
302 contractOp, "Accumulator read is not by transfer_read or load");
303
304 // Shuffle the output of contract operations before it's use.
305 LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
306 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
307
308 if (failed(writeShuffle))
309 return rewriter.notifyMatchFailure(
310 contractOp,
311 "Write to accumulator is not by transfer_write or store");
312 }
313
314 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
315 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
316 ? pairContractOp.getRhs()
317 : pairContractOp.getLhs();
318
319 // Get the non-packed A or B matrix's vector<32xbf16> elements.
320 Operation *nonUnitDimReadOp =
321 traceToVectorReadLikeParentOperation(nonUnitDimOperand);
322 Operation *nonUnitDimReadOpPairContract =
323 traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
324
325 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
326 return rewriter.notifyMatchFailure(
327 contractOp, "Could not find a valid contract pair");
328
329 if (contractOp->getBlock() ==
330 nonUnitDimReadOpPairContract->getBlock() &&
331 contractOp->isBeforeInBlock(nonUnitDimReadOpPairContract))
332 return rewriter.notifyMatchFailure(
333 contractOp,
334 "The load/read operation of pair contract operation is "
335 "after the contractOp");
336
337 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
338 ? contractOp.getRhsType()
339 : contractOp.getLhsType();
340
341 packNonUnitDimOperandToVNNI(
342 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
343 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
344 nonUnitDimTy);
345
346 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
347 : contractOp.getLhs();
348 }
349 }
350
351 rewriter.setInsertionPoint(contractOp);
352 auto loc = contractOp.getLoc();
353 auto castAcc = vector::ShapeCastOp::create(
354 rewriter, loc,
355 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
356 contractOp.getAcc());
357
358 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
359 ? contractOp.getRhsType()
360 : contractOp.getLhsType();
361 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
362 : contractOp.getRhsType();
363
364 Value dp;
365
366 auto castNonUnitDim = vector::ShapeCastOp::create(
367 rewriter, loc,
368 VectorType::get(blockingFactor * nonUnitDimValue,
369 nonUnitDimTy.getElementType()),
370 nonUnitDimOperand);
371
372 auto castUnitDim = vector::ShapeCastOp::create(
373 rewriter, loc,
374 VectorType::get(blockingFactor, unitDimTy.getElementType()),
375 unitDimOperand);
376 auto bitcastUnitDim = vector::BitCastOp::create(
377 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
378 castUnitDim);
379 auto broadcastUnitDim = vector::BroadcastOp::create(
380 rewriter, loc,
381 VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
382 bitcastUnitDim);
383 auto bitcastUnitDimPkType = vector::BitCastOp::create(
384 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
385
386 if (lhsTy.getElementType().isBF16()) {
387 dp = x86vector::DotBF16Op::create(
388 rewriter, loc,
389 VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
390 bitcastUnitDimPkType, castNonUnitDim);
391 }
392
393 if (lhsTy.getElementType().isSignlessInteger(8)) {
394 if (nonUnitDimAcc.front() == 16) {
395 dp = x86vector::AVX10DotInt8Op::create(
396 rewriter, loc,
397 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
398 castAcc, bitcastUnitDimPkType, castNonUnitDim);
399 } else {
400 dp = x86vector::DotInt8Op::create(
401 rewriter, loc,
402 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
403 castAcc, bitcastUnitDimPkType, castNonUnitDim);
404 }
405 }
406
407 if (!dp)
408 return failure();
409
410 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
411 rewriter.replaceOp(contractOp, castDp);
412 return success();
413 }
414};
415
416} // namespace
417
420 patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
421}
return success()
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:350
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Operation * traceToVectorReadLikeParentOperation(Value v)
Operation * traceToVectorWriteLikeUserOperation(Value v)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...