MLIR 23.0.0git
VectorContractToPackedTypeDotProduct.cpp
Go to the documentation of this file.
1//===- VectorContractToPackedTypeDotProduct.cpp ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17
19#include "mlir/IR/Dominance.h"
21
22#include "mlir/Pass/Pass.h"
24
25using namespace mlir;
26using namespace mlir::vector;
27using namespace mlir::x86;
28
29namespace {
30
31// Returns true if the A or B matrix vector is packed (shuffled) to
32// VNNI layout, already.
33static bool isNonUnitDimOperandShuffled(Value nonUnitDimOperand) {
34 if (Operation *defOp = nonUnitDimOperand.getDefiningOp()) {
35 if (isa<vector::ShuffleOp>(defOp))
36 return true;
37
38 if (isa<vector::ShapeCastOp>(defOp)) {
39 Operation *defOpShpCst = defOp->getOperand(0).getDefiningOp();
40 if (isa<vector::ShuffleOp>(defOpShpCst))
41 return true;
42 }
43 }
44
45 return false;
46}
47
48static void rewriteUses(mlir::Value oldVal, mlir::Value newVal,
49 mlir::Operation *targetContract,
50 mlir::PatternRewriter &rewriter) {
51 for (mlir::OpOperand &use : llvm::make_early_inc_range(oldVal.getUses())) {
52
53 mlir::Operation *user = use.getOwner();
54 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
55 mlir::isa<mlir::scf::ForOp>(user)) {
56 use.set(newVal);
57 }
58 }
59}
60
61// Function to convert the flat layout A or B matrix vector<32xbf16>
62// into VNNI packed layout using the vpunpack operations
63static void packNonUnitDimOperandToVNNI(mlir::PatternRewriter &rewriter,
64 mlir::Operation *opA,
65 mlir::Operation *opB,
66 mlir::vector::ContractionOp contractA,
67 mlir::vector::ContractionOp contractB,
68 int64_t nonUnitDimAcc,
69 mlir::VectorType Ty) {
70
71 bool opABeforeopB = opA->isBeforeInBlock(opB);
72
73 if (opABeforeopB)
74 rewriter.moveOpAfter(opB, opA);
75 else
76 rewriter.moveOpAfter(opA, opB);
77
78 mlir::Operation *insertAfter = opABeforeopB ? opB : opA;
79
80 rewriter.setInsertionPointAfter(insertAfter);
81 mlir::Location loc = insertAfter->getLoc();
82
83 auto elemTy = Ty.getElementType();
84 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
85
86 auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
87 opA->getResult(0));
88 auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
89 opB->getResult(0));
90
91 static constexpr int64_t maskLo[] = {
92 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
93 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
94 static constexpr int64_t maskHi[] = {
95 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
96 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
97
98 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
99 castB, maskLo);
100 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
101 castB, maskHi);
102
103 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
104 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
105
106 rewriteUses(opA->getResult(0), newA.getResult(), contractA, rewriter);
107 rewriteUses(opB->getResult(0), newB.getResult(), contractB, rewriter);
108}
109
110// Implements packed type outer product contraction as a sequence
111// of broadcast and packed dot-product operations.
112//
113// For example - for bf16 type (VNNI):
114// ```
115// vector.contract <1x1x2xbf16>, <1x16x2xbf16> into <1x16xf32>
116// ```
117// to
118// ```
119// vector.broadcast %lhs to <32xbf16>
120// x86.avx512.dot vector<32xbf16> -> vector<16xf32>
121// ```
122//
123// For example - for bf16 type (Flat layout):
124// ```
125// %1 = vector.load -> <2x16xbf16>
126// %2 = vector.load -> <2x16xbf16>
127// vector.contract <1x2xbf16>, %1 into <1x16xf32>
128// vector.contract <1x2xbf16>, %2 into <1x16xf32>
129// ```
130// to
131// ```
132// %1 = vector.load -> <2x16xbf16>
133// %2 = vector.load -> <2x16xbf16>
134// %3 = vector.shuffle %1, %2 [0, 32, 1, ... 27, 59]
135// %4 = vector.shuffle %1, %2 [4, 36, 5, ... 31, 63]
136// vector.broadcast %lhs to <32xbf16>
137// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
138// vector.broadcast %lhs to <32xbf16>
139// x86.avx512.dot vector<32xbf16>, %3 -> vector<16xf32>
140// ```
141struct VectorContractToPackedTypeDotProduct
142 : public OpRewritePattern<vector::ContractionOp> {
143 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
144
145 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
146 PatternRewriter &rewriter) const override {
147
148 if (contractOp.getKind() != vector::CombiningKind::ADD)
149 return rewriter.notifyMatchFailure(contractOp,
150 "Expects add combining kind.");
151
152 VectorType lhsTy = contractOp.getLhsType();
153 if (!lhsTy.getElementType().isBF16() &&
154 !lhsTy.getElementType().isSignlessInteger(8))
155 return rewriter.notifyMatchFailure(
156 contractOp, "Only BF16/Int8 lowering is supported.");
157
158 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
159 bool isVnni =
160 isInVnniLayout(contractOp.getOperation(),
161 contractOp.getIndexingMapsArray(), blockingFactor);
162
163 if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
164 return failure();
165
166 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
167 if (!accTy)
168 return rewriter.notifyMatchFailure(contractOp, "Wrong accmulator type.");
169
170 ArrayRef<int64_t> accShape = accTy.getShape();
171 llvm::SmallVector<int64_t> nonUnitDimAcc;
172 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
173 [](int64_t dim) { return dim != 1; });
174 if (nonUnitDimAcc.size() != 1)
175 return rewriter.notifyMatchFailure(
176 contractOp, "A or B should be a non-unit dim in acc.");
177
178 int64_t nonUnitDimValue = nonUnitDimAcc.front();
179 // Non-unit dimensions should match the vector length of BF16 or Int8
180 // dot-product.
181 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
182 nonUnitDimValue != 8 && nonUnitDimValue != 16)
183 return rewriter.notifyMatchFailure(
184 contractOp, "BF16 dot-product operation expects non-unit (LHR or "
185 "RHS) dim and acc dim of size 4/8/16.");
186
187 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
188 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
189 nonUnitDimAcc.front() == nonUnitDimValue)
190 return rewriter.notifyMatchFailure(
191 contractOp, "Int8 dot-product operation expects non-unit (LHR or "
192 "RHS) dim and acc dim of size 4/8/16.");
193
194 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
195 llvm::SmallVector<int64_t> nonUnitDimLhs;
196 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
197 [](int64_t dim) { return dim != 1; });
198
199 VectorType rhsTy = contractOp.getRhsType();
200 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
201 llvm::SmallVector<int64_t> nonUnitDimRhs;
202 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
203 [](int64_t dim) { return dim != 1; });
204
205 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
206 return rewriter.notifyMatchFailure(contractOp,
207 "Excepts unit dimensions for either "
208 "LHS or RHS shape.");
209
210 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
211 return rewriter.notifyMatchFailure(
212 contractOp,
213 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
214
215 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
216 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
217 : nonUnitDimRhs.front();
218
219 if (!isVnni && (extraFlatDim != blockingFactor))
220 return rewriter.notifyMatchFailure(
221 contractOp, "The K or reduction dim for flat layout should be 2.");
222
223 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
224 (lhsTy.getElementType().isSignlessInteger(8) &&
225 !accTy.getElementType().isSignlessInteger(32)))
226 return rewriter.notifyMatchFailure(contractOp,
227 "Only F32 for BF16 or Int32 for Int8 "
228 "accumulation type is supported.");
229
230 Value unitDimOperand =
231 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
232 Value nonUnitDimOperand =
233 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
234
235 // If the A or B matrix vector of the contact operation is not packed, then
236 // find it's pair contract operation and pack (shuffle) them to VNNI packed.
237 if (!isVnni) {
238 vector::ContractionOp pairContractOp;
239 Operation *nextOp = contractOp;
240 while ((nextOp = nextOp->getNextNode())) {
241 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
242
243 if (!contOp)
244 continue;
245
246 if (validatePairVectorContract(contractOp, contOp,
247 rhsHasMultipleNonUnitDims,
248 nonUnitDimValue)) {
249 pairContractOp = contOp;
250 break;
251 }
252 }
253
254 // If the accumulators are shuffled we get nullptr else the
255 // transfer_read or load operations.
256 Operation *accRead =
257 traceToVectorReadLikeParentOperation(contractOp.getAcc());
258
259 if (!pairContractOp &&
260 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
261 return rewriter.notifyMatchFailure(contractOp,
262 "Could not find a contract pair");
263
264 // Validate and shuffle the accumulator
265 if (accRead) {
266 // Trace back to the load or transfer_read operations of the contract
267 // accumulators.
268 Operation *accReadOp0 =
269 traceToVectorReadLikeParentOperation(contractOp.getAcc());
270 Operation *accReadOp1 =
271 traceToVectorReadLikeParentOperation(pairContractOp.getAcc());
272
273 // Iterate down to find the users of contact operations until it is
274 // store or transfer_write.
275 Operation *resultWriteOp0 =
276 traceToVectorWriteLikeUserOperation(contractOp.getResult());
277 Operation *resultWriteOp1 =
278 traceToVectorWriteLikeUserOperation(pairContractOp.getResult());
279
280 if (!accReadOp0 || !accReadOp1)
281 return rewriter.notifyMatchFailure(
282 contractOp,
283 "Operands doesn't have load or transfer_read as it's parent op");
284
285 if (!resultWriteOp0 || !resultWriteOp1)
286 return rewriter.notifyMatchFailure(
287 contractOp,
288 "The use of contract operations are neither vector.store "
289 "or transfer_write or has multiple users.");
290
291 if (contractOp->getBlock() == accReadOp1->getBlock() &&
292 contractOp->isBeforeInBlock(accReadOp1))
293 return rewriter.notifyMatchFailure(
294 contractOp,
295 "The load/read operation of pair contract operation is "
296 "after the contractOp");
297
298 if (pairContractOp->getBlock() == resultWriteOp0->getBlock() &&
299 resultWriteOp0->isBeforeInBlock(pairContractOp))
300 return rewriter.notifyMatchFailure(
301 contractOp, "The store/write operation of contract operation is "
302 "before the pair contract operation");
303 // Shuffle the accumulators of the contract operations.
304 LogicalResult readShuffle =
305 shuffleAfterReadLikeOp(rewriter, accReadOp0, accReadOp1, contractOp,
306 pairContractOp, nonUnitDimValue, accTy);
307
308 if (failed(readShuffle))
309 return rewriter.notifyMatchFailure(
310 contractOp, "Accumulator read is not by transfer_read or load");
311
312 // Shuffle the output of contract operations before it's use.
313 LogicalResult writeShuffle = shuffleBeforeWriteLikeOp(
314 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
315
316 if (failed(writeShuffle))
317 return rewriter.notifyMatchFailure(
318 contractOp,
319 "Write to accumulator is not by transfer_write or store");
320 }
321
322 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
323 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
324 ? pairContractOp.getRhs()
325 : pairContractOp.getLhs();
326
327 // Get the non-packed A or B matrix's vector<32xbf16> elements.
328 Operation *nonUnitDimReadOp =
329 traceToVectorReadLikeParentOperation(nonUnitDimOperand);
330 Operation *nonUnitDimReadOpPairContract =
331 traceToVectorReadLikeParentOperation(nonUnitDimOperandPairContract);
332
333 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
334 return rewriter.notifyMatchFailure(
335 contractOp, "Could not find a valid contract pair");
336
337 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
338 ? contractOp.getRhsType()
339 : contractOp.getLhsType();
340
341 packNonUnitDimOperandToVNNI(
342 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
343 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
344 nonUnitDimTy);
345
346 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
347 : contractOp.getLhs();
348 }
349 }
350
351 rewriter.setInsertionPoint(contractOp);
352 auto loc = contractOp.getLoc();
353 auto castAcc = vector::ShapeCastOp::create(
354 rewriter, loc,
355 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
356 contractOp.getAcc());
357
358 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
359 ? contractOp.getRhsType()
360 : contractOp.getLhsType();
361 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
362 : contractOp.getRhsType();
363
364 Value dp;
365
366 auto castNonUnitDim = vector::ShapeCastOp::create(
367 rewriter, loc,
368 VectorType::get(blockingFactor * nonUnitDimValue,
369 nonUnitDimTy.getElementType()),
370 nonUnitDimOperand);
371
372 auto castUnitDim = vector::ShapeCastOp::create(
373 rewriter, loc,
374 VectorType::get(blockingFactor, unitDimTy.getElementType()),
375 unitDimOperand);
376 auto bitcastUnitDim = vector::BitCastOp::create(
377 rewriter, loc, VectorType::get({1}, rewriter.getIntegerType(32)),
378 castUnitDim);
379 auto broadcastUnitDim = vector::BroadcastOp::create(
380 rewriter, loc,
381 VectorType::get({nonUnitDimValue}, rewriter.getIntegerType(32)),
382 bitcastUnitDim);
383 auto bitcastUnitDimPkType = vector::BitCastOp::create(
384 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
385
386 if (lhsTy.getElementType().isBF16()) {
387 dp = x86::DotBF16Op::create(
388 rewriter, loc,
389 VectorType::get(nonUnitDimValue, rewriter.getF32Type()), castAcc,
390 bitcastUnitDimPkType, castNonUnitDim);
391 }
392
393 if (lhsTy.getElementType().isSignlessInteger(8)) {
394 if (nonUnitDimAcc.front() == 16) {
395 dp = x86::AVX10DotInt8Op::create(
396 rewriter, loc,
397 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
398 castAcc, bitcastUnitDimPkType, castNonUnitDim);
399 } else {
400 dp = x86::DotInt8Op::create(
401 rewriter, loc,
402 VectorType::get(nonUnitDimValue, rewriter.getIntegerType(32)),
403 castAcc, bitcastUnitDimPkType, castNonUnitDim);
404 }
405 }
406
407 if (!dp)
408 return failure();
409
410 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
411 rewriter.replaceOp(contractOp, castDp);
412 return success();
413 }
414};
415
416} // namespace
417
419 RewritePatternSet &patterns) {
420 patterns.add<VectorContractToPackedTypeDotProduct>(patterns.getContext());
421}
return success()
FloatType getF32Type()
Definition Builders.cpp:47
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents an operand of an operation.
Definition Value.h:254
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Value getOperand(unsigned idx)
Definition Operation.h:379
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:436
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:283
Operation * traceToVectorWriteLikeUserOperation(Value v)
Definition X86Utils.cpp:186
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Definition X86Utils.cpp:42
Operation * traceToVectorReadLikeParentOperation(Value v)
Definition X86Utils.cpp:146
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
Definition X86Utils.cpp:234
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Definition X86Utils.cpp:340
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...