MLIR 23.0.0git
VectorContractToPackedTypeDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17
19#include "mlir/IR/Dominance.h"
21
22#include "mlir/Pass/Pass.h"
24
25using namespace mlir;
26using namespace mlir::vector;
27using namespace mlir::x86;
28
29namespace {
30
31// Returns true if the A or B matrix vector is packed (shuffled) to
32// VNNI layout, already.
33static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
34 if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
35 if (isa<vector::ShuffleOp>(defOp))
36 return true;
37
38 if (isa<vector::ShapeCastOp>(defOp)) {
39 Operation *defOpShpCst = defOp->getOperand(0).getDefiningOp();
40 if (isa<vector::ShuffleOp>(defOpShpCst))
41 return true;
42 }
43 }
44
45 return false;
46}
47
48static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
49 mlir::Operation *targetContract,
50 mlir::PatternRewriter &rewriter) {
51 for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
52 mlir::Operation *user = use.getOwner();
53 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
54 mlir::isa<mlir::scf::ForOp>(user)) {
55 rewriter.modifyOpInPlace(user, [&]() { use.set(newVal); });
56 }
57 }
58}
59
60// Function to convert the flat layout A or B matrix vector<32xbf16>
61// into VNNI packed layout using the vpunpack operations
62static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
63 mlir::Operation *opA,
64 mlir::Operation *opB,
65 mlir::vector::ContractionOp contractA,
66 mlir::vector::ContractionOp contractB,
67 int64_t nonUnitDimAcc,
68 mlir::VectorType Ty) {
69
70 bool opABeforeopB = opA->isBeforeInBlock(opB);
71
72 if (opABeforeopB)
73 rewriter.moveOpAfter(opB, opA);
74 else
75 rewriter.moveOpAfter(opA, opB);
76
77 mlir::Operation *insertAfter = opABeforeopB ? opB : opA;
78
79 rewriter.setInsertionPointAfter(insertAfter);
80 mlir::Location loc = insertAfter->getLoc();
81
82 auto elemTy = Ty.getElementType();
83 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
84
85 auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
86 opA->getResult(0));
87 auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
88 opB->getResult(0));
89
90 static constexpr int64_t maskLo[] = {
91 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
92 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
93 static constexpr int64_t maskHi[] = {
94 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
95 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
96
97 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
98 castB, maskLo);
99 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
100 castB, maskHi);
101
102 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
103 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
104
105 rewriteUses(opA->getResult(0), newA.getResult(), contractA, rewriter);
106 rewriteUses(opB->getResult(0), newB.getResult(), contractB, rewriter);
107}
108
109// Implements packed type outer product contraction as a sequence
110// of broadcast and packed dot-product operations.
111//
112// For example - for bf16 type (VNNI):
113// ```
114// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
115// ```
116// to
117// ```
118// vector.broadcast %lhs to <32xbf16>
119// x86.avx512.dot vector<32xbf16> -> vector<16xf32>
120// ```
121//
122// For example - for bf16 type (Flat layout):
123// ```
124// %1 = vector.load -> <2x16xbf16>
125// %2 = vector.load -> <2x16xbf16>
126// vector.contract <1x2xbf16>, %1 into <1x16xf32>
127// vector.contract <1x2xbf16>, %2 into <1x16xf32>
128// ```
129// to
130// ```
131// %1 = vector.load -> <2x16xbf16>
132// %2 = vector.load -> <2x16xbf16>
133// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
134// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
135// vector.broadcast %lhs to <32xbf16>
136// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
137// vector.broadcast %lhs to <32xbf16>
138// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
139// ```
140struct VectorContractToPackedTypeDotProduct
141 : public OpRewritePattern<vector::ContractionOp> {
142 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
143
144 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
145 PatternRewriter &rewriter) const override {
146
147 if (contractOp.getKind() != vector::CombiningKind::ADD)
148 return rewriter.notifyMatchFailure(contractOp,
149 "Expects add combining kind.");
150
151 VectorType lhsTy = contractOp.getLhsType();
152 if (!lhsTy.getElementType().isBF16() &&
153 !lhsTy.getElementType().isSignlessInteger(8))
154 return rewriter.notifyMatchFailure(
155 contractOp, "Only BF16/Int8 lowering is supported.");
156
157 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
158 bool isVnni =
159 isInVnniLayout(contractOp.getOperation(),
160 contractOp.getIndexingMapsArray(), blockingFactor);
161
162 if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
163 return failure();
164
165 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
166 if (!accTy)
167 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
168
169 ArrayRef<int64_t> accShape = accTy.getShape();
170 llvm::SmallVector<int64_t> nonUnitDimAcc;
171 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
172 [](int64_t dim) { return dim != 1; });
173 if (nonUnitDimAcc.size() != 1)
174 return rewriter.notifyMatchFailure(
175 contractOp, "A or B should be a non-unit dim in acc.");
176
177 int64_t nonUnitDimValue = nonUnitDimAcc.front();
178 // Non-unit dimensions should match the vector length of BF16 or Int8
179 // dot-product.
180 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
181 nonUnitDimValue != 8 && nonUnitDimValue != 16)
182 return rewriter.notifyMatchFailure(
183 contractOp, "BF16 dot-product operation expects non-unit (LHR or "
184 "RHS) dim and acc dim of size 4/8/16.");
185
186 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
187 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
188 nonUnitDimAcc.front() == nonUnitDimValue)
189 return rewriter.notifyMatchFailure(
190 contractOp, "Int8 dot-product operation expects non-unit (LHR or "
191 "RHS) dim and acc dim of size 4/8/16.");
192
193 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
194 llvm::SmallVector<int64_t> nonUnitDimLhs;
195 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
196 [](int64_t dim) { return dim != 1; });
197
198 VectorType rhsTy = contractOp.getRhsType();
199 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
200 llvm::SmallVector<int64_t> nonUnitDimRhs;
201 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
202 [](int64_t dim) { return dim != 1; });
203
204 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
205 return rewriter.notifyMatchFailure(contractOp,
206 "Excepts unit dimensions for either "
207 "LHS or RHS shape.");
208
209 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
210 return rewriter.notifyMatchFailure(
211 contractOp,
212 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
213
214 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
215 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
216 : nonUnitDimRhs.front();
217
218 if (!isVnni && (extraFlatDim != blockingFactor))
219 return rewriter.notifyMatchFailure(
220 contractOp, "The K or reduction dim for flat layout should be 2.");
221
222 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
223 (lhsTy.getElementType().isSignlessInteger(8) &&
224 !accTy.getElementType().isSignlessInteger(32)))
225 return rewriter.notifyMatchFailure(contractOp,
226 "Only F32 for BF16 or Int32 for Int8 "
227 "accumulation type is supported.");
228
229 Value unitDimOperand =
230 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
231 Value nonUnitDimOperand =
232 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
233
234 // If the A or B matrix vector of the contact operation is not packed, then
235 // find it's pair contract operation and pack (shuffle) them to VNNI packed.
236 if (!isVnni) {
237 vector::ContractionOp pairContractOp;
238 Operation *nextOp = contractOp;
239 while ((nextOp = nextOp->getNextNode())) {
240 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
241
242 if (!contOp)
243 continue;
244
245 if (validatePairVectorContract(contractOp, contOp,
246 rhsHasMultipleNonUnitDims,
247 nonUnitDimValue)) {
248 pairContractOp = contOp;
249 break;
250 }
251 }
252
253 // If the accumulators are shuffled we get nullptr else the
254 // transfer_read or load operations.
255 Operation *accRead =
256 traceToVectorReadLikeParentOperation(contractOp.getAcc());
257
258 if (!pairContractOp &&
259 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
260 return rewriter.notifyMatchFailure(contractOp,
261 "Could not find a contract pair");
262
263 // Validate and shuffle the accumulator
264 if (accRead) {
265 // Trace back to the load or transfer_read operations of the contract
266 // accumulators.
267 Operation *accReadOp0 =
268 traceToVectorReadLikeParentOperation(contractOp.getAcc());
269 Operation *accReadOp1 =
270 traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
271
272 // Iterate down to find the users of contact operations until it is
273 // store or transfer_write.
274 Operation *resultWriteOp0 =
275 traceToVectorWriteLikeUserOperation(contractOp.getResult());
276 Operation *resultWriteOp1 =
277 traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
278
279 if (!accReadOp0 || !accReadOp1)
280 return rewriter.notifyMatchFailure(
281 contractOp,
282 "Operands doesn't have load or transfer_read as it's parent op");
283
284 if (!resultWriteOp0 || !resultWriteOp1)
285 return rewriter.notifyMatchFailure(
286 contractOp,
287 "The use of contract operations are neither vector.store "
288 "or transfer_write or has multiple users.");
289
290 if (contractOp->getBlock() == accReadOp1->getBlock() &&
291 contractOp->isBeforeInBlock(accReadOp1))
292 return rewriter.notifyMatchFailure(
293 contractOp,
294 "The load/read operation of pair contract operation is "
295 "after the contractOp");
296
297 if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
298 resultWriteOp0->isBeforeInBlock(pairContractOp))
299 return rewriter.notifyMatchFailure(
300 contractOp, "The store/write operation of contract operation is "
301 "before the pair contract operation");
302 // Shuffle the accumulators of the contract operations.
303 LogicalResult readShuffle =
304 shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
305 pairContractOp, nonUnitDimValue, accTy);
306
307 if (failed(readShuffle))
308 return rewriter.notifyMatchFailure(
309 contractOp, "Accumulator read is not by transfer_read or load");
310
311 // Shuffle the output of contract operations before it's use.
312 LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
313 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
314
315 if (failed(writeShuffle))
316 return rewriter.notifyMatchFailure(
317 contractOp,
318 "Write to accumulator is not by transfer_write or store");
319 }
320
321 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
322 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
323 ? pairContractOp.getRhs()
324 : pairContractOp.getLhs();
325
326 // Get the non-packed A or B matrix's vector<32xbf16> elements.
327 Operation *nonUnitDimReadOp =
328 traceToVectorReadLikeParentOperation(nonUnitDimOperand);
329 Operation *nonUnitDimReadOpPairContract =
330 traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
331
332 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
333 return rewriter.notifyMatchFailure(
334 contractOp, "Could not find a valid contract pair");
335
336 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
337 ? contractOp.getRhsType()
338 : contractOp.getLhsType();
339
340 packNonUnitDimOperandToVNNI(
341 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
342 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
343 nonUnitDimTy);
344
345 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
346 : contractOp.getLhs();
347 }
348 }
349
350 rewriter.setInsertionPoint(contractOp);
351 auto loc = contractOp.getLoc();
352 auto castAcc = vector::ShapeCastOp::create(
353 rewriter, loc,
354 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
355 contractOp.getAcc());
356
357 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
358 ? contractOp.getRhsType()
359 : contractOp.getLhsType();
360 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
361 : contractOp.getRhsType();
362
363 Value dp;
364
365 auto castNonUnitDim = vector::ShapeCastOp::create(
366 rewriter, loc,
367 VectorType::get(blockingFactor * nonUnitDimValue,
368 nonUnitDimTy.getElementType()),
369 nonUnitDimOperand);
370
371 auto castUnitDim = vector::ShapeCastOp::create(
372 rewriter, loc,
373 VectorType::get(blockingFactor, unitDimTy.getElementType()),
374 unitDimOperand);
375 auto bitcastUnitDim = vector::BitCastOp::create(
376 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
377 castUnitDim);
378 auto broadcastUnitDim = vector::BroadcastOp::create(
379 rewriter, loc,
380 VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
381 bitcastUnitDim);
382 auto bitcastUnitDimPkType = vector::BitCastOp::create(
383 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
384
385 if (lhsTy.getElementType().isBF16()) {
386 dp = x86::avx512::DotBF16Op::create(
387 rewriter, loc,
388 VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
389 bitcastUnitDimPkType, castNonUnitDim);
390 }
391
392 if (lhsTy.getElementType().isSignlessInteger(8)) {
393 if (nonUnitDimAcc.front() == 16) {
394 dp = x86::avx10::AVX10DotInt8Op::create(
395 rewriter, loc,
396 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
397 castAcc, bitcastUnitDimPkType, castNonUnitDim);
398 } else {
399 dp = x86::avx::DotInt8Op::create(
400 rewriter, loc,
401 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
402 castAcc, bitcastUnitDimPkType, castNonUnitDim);
403 }
404 }
405
406 if (!dp)
407 return failure();
408
409 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
410 rewriter.replaceOp(contractOp, castDp);
411 return success();
412 }
413};
414
415} // namespace
416
418 RewritePatternSet &patterns) {
419 patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
420}
return success()
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:376
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:283
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:186
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:146
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:234
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Definition X86Utils.cpp:342
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...