33static bool isNonUnitDimOperandShuffled(
Value nonUnitDimOperand) {
35 if (isa<vector::ShuffleOp>(defOp))
38 if (isa<vector::ShapeCastOp>(defOp)) {
40 if (isa<vector::ShuffleOp>(defOpShpCst))
54 if (mlir::isa<mlir::vector::ContractionOp>(user) ||
55 mlir::isa<mlir::scf::ForOp>(user)) {
66 mlir::vector::ContractionOp contractA,
67 mlir::vector::ContractionOp contractB,
69 mlir::VectorType Ty) {
83 auto elemTy = Ty.getElementType();
84 auto flatTy = mlir::VectorType::get(nonUnitDimAcc, elemTy);
86 auto castA = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
88 auto castB = mlir::vector::ShapeCastOp::create(rewriter, loc, flatTy,
91 static constexpr int64_t maskLo[] = {
92 0, 32, 1, 33, 2, 34, 3, 35, 8, 40, 9, 41, 10, 42, 11, 43,
93 16, 48, 17, 49, 18, 50, 19, 51, 24, 56, 25, 57, 26, 58, 27, 59};
94 static constexpr int64_t maskHi[] = {
95 4, 36, 5, 37, 6, 38, 7, 39, 12, 44, 13, 45, 14, 46, 15, 47,
96 20, 52, 21, 53, 22, 54, 23, 55, 28, 60, 29, 61, 30, 62, 31, 63};
98 auto shuffleLo = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
100 auto shuffleHi = mlir::vector::ShuffleOp::create(rewriter, loc, flatTy, castA,
103 auto newA = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleLo);
104 auto newB = mlir::vector::ShapeCastOp::create(rewriter, loc, Ty, shuffleHi);
106 rewriteUses(opA->
getResult(0), newA.getResult(), contractA, rewriter);
107 rewriteUses(opB->
getResult(0), newB.getResult(), contractB, rewriter);
141struct VectorContractToPackedTypeDotProduct
143 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
145 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
146 PatternRewriter &rewriter)
const override {
148 if (contractOp.getKind() != vector::CombiningKind::ADD)
150 "Expects add combining kind.");
152 VectorType lhsTy = contractOp.getLhsType();
153 if (!lhsTy.getElementType().isBF16() &&
154 !lhsTy.getElementType().isSignlessInteger(8))
156 contractOp,
"Only BF16/Int8 lowering is supported.");
158 unsigned int blockingFactor = lhsTy.getElementType().isBF16() ? 2 : 4;
161 contractOp.getIndexingMapsArray(), blockingFactor);
163 if (lhsTy.getElementType().isSignlessInteger(8) && !isVnni)
166 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
170 ArrayRef<int64_t> accShape = accTy.getShape();
171 llvm::SmallVector<int64_t> nonUnitDimAcc;
172 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
173 [](int64_t dim) {
return dim != 1; });
174 if (nonUnitDimAcc.size() != 1)
176 contractOp,
"A or B should be a non-unit dim in acc.");
178 int64_t nonUnitDimValue = nonUnitDimAcc.front();
181 if (lhsTy.getElementType().isBF16() && nonUnitDimValue != 4 &&
182 nonUnitDimValue != 8 && nonUnitDimValue != 16)
184 contractOp,
"BF16 dot-product operation expects non-unit (LHR or "
185 "RHS) dim and acc dim of size 4/8/16.");
187 if (lhsTy.getElementType().isSignlessInteger(8) && nonUnitDimValue != 4 &&
188 nonUnitDimValue != 8 && nonUnitDimValue != 16 &&
189 nonUnitDimAcc.front() == nonUnitDimValue)
191 contractOp,
"Int8 dot-product operation expects non-unit (LHR or "
192 "RHS) dim and acc dim of size 4/8/16.");
194 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
195 llvm::SmallVector<int64_t> nonUnitDimLhs;
196 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
197 [](int64_t dim) {
return dim != 1; });
199 VectorType rhsTy = contractOp.getRhsType();
200 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
201 llvm::SmallVector<int64_t> nonUnitDimRhs;
202 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
203 [](int64_t dim) {
return dim != 1; });
205 if ((nonUnitDimLhs.size() - 1) > 0 && (nonUnitDimRhs.size() - 1) > 0)
207 "Excepts unit dimensions for either "
208 "LHS or RHS shape.");
210 if ((nonUnitDimLhs.size() - 1) != 1 && (nonUnitDimRhs.size() - 1) != 1)
213 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
215 bool rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
216 int64_t extraFlatDim = rhsHasMultipleNonUnitDims ? nonUnitDimLhs.front()
217 : nonUnitDimRhs.front();
219 if (!isVnni && (extraFlatDim != blockingFactor))
221 contractOp,
"The K or reduction dim for flat layout should be 2.");
223 if ((lhsTy.getElementType().isBF16() && !accTy.getElementType().isF32()) ||
224 (lhsTy.getElementType().isSignlessInteger(8) &&
225 !accTy.getElementType().isSignlessInteger(32)))
227 "Only F32 for BF16 or Int32 for Int8 "
228 "accumulation type is supported.");
230 Value unitDimOperand =
231 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
232 Value nonUnitDimOperand =
233 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
238 vector::ContractionOp pairContractOp;
239 Operation *nextOp = contractOp;
240 while ((nextOp = nextOp->getNextNode())) {
241 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
247 rhsHasMultipleNonUnitDims,
249 pairContractOp = contOp;
259 if (!pairContractOp &&
260 (!isNonUnitDimOperandShuffled(nonUnitDimOperand) || accRead))
262 "Could not find a contract pair");
268 Operation *accReadOp0 =
270 Operation *accReadOp1 =
275 Operation *resultWriteOp0 =
277 Operation *resultWriteOp1 =
280 if (!accReadOp0 || !accReadOp1)
283 "Operands doesn't have load or transfer_read as it's parent op");
285 if (!resultWriteOp0 || !resultWriteOp1)
288 "The use of contract operations are neither vector.store "
289 "or transfer_write or has multiple users.");
291 if (contractOp->getBlock() == accReadOp1->
getBlock() &&
292 contractOp->isBeforeInBlock(accReadOp1))
295 "The load/read operation of pair contract operation is "
296 "after the contractOp");
298 if (pairContractOp->getBlock() == resultWriteOp0->
getBlock() &&
301 contractOp,
"The store/write operation of contract operation is "
302 "before the pair contract operation");
304 LogicalResult readShuffle =
306 pairContractOp, nonUnitDimValue, accTy);
310 contractOp,
"Accumulator read is not by transfer_read or load");
314 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDimValue, accTy);
319 "Write to accumulator is not by transfer_write or store");
322 if (!isNonUnitDimOperandShuffled(nonUnitDimOperand)) {
323 Value nonUnitDimOperandPairContract = rhsHasMultipleNonUnitDims
324 ? pairContractOp.getRhs()
325 : pairContractOp.getLhs();
328 Operation *nonUnitDimReadOp =
330 Operation *nonUnitDimReadOpPairContract =
333 if (!nonUnitDimReadOp || !nonUnitDimReadOpPairContract)
335 contractOp,
"Could not find a valid contract pair");
337 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
338 ? contractOp.getRhsType()
339 : contractOp.getLhsType();
341 packNonUnitDimOperandToVNNI(
342 rewriter, nonUnitDimReadOp, nonUnitDimReadOpPairContract,
343 contractOp, pairContractOp, blockingFactor * nonUnitDimValue,
346 nonUnitDimOperand = rhsHasMultipleNonUnitDims ? contractOp.getRhs()
347 : contractOp.getLhs();
352 auto loc = contractOp.getLoc();
353 auto castAcc = vector::ShapeCastOp::create(
355 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
356 contractOp.getAcc());
358 VectorType nonUnitDimTy = rhsHasMultipleNonUnitDims
359 ? contractOp.getRhsType()
360 : contractOp.getLhsType();
361 VectorType unitDimTy = rhsHasMultipleNonUnitDims ? contractOp.getLhsType()
362 : contractOp.getRhsType();
366 auto castNonUnitDim = vector::ShapeCastOp::create(
368 VectorType::get(blockingFactor * nonUnitDimValue,
369 nonUnitDimTy.getElementType()),
372 auto castUnitDim = vector::ShapeCastOp::create(
374 VectorType::get(blockingFactor, unitDimTy.getElementType()),
376 auto bitcastUnitDim = vector::BitCastOp::create(
379 auto broadcastUnitDim = vector::BroadcastOp::create(
383 auto bitcastUnitDimPkType = vector::BitCastOp::create(
384 rewriter, loc, castNonUnitDim.getResult().getType(), broadcastUnitDim);
386 if (lhsTy.getElementType().isBF16()) {
387 dp = x86::DotBF16Op::create(
389 VectorType::get(nonUnitDimValue, rewriter.
getF32Type()), castAcc,
390 bitcastUnitDimPkType, castNonUnitDim);
393 if (lhsTy.getElementType().isSignlessInteger(8)) {
394 if (nonUnitDimAcc.front() == 16) {
395 dp = x86::AVX10DotInt8Op::create(
398 castAcc, bitcastUnitDimPkType, castNonUnitDim);
400 dp = x86::DotInt8Op::create(
403 castAcc, bitcastUnitDimPkType, castNonUnitDim);
410 auto castDp = vector::ShapeCastOp::create(rewriter, loc, accTy, dp);
420 patterns.
add<VectorContractToPackedTypeDotProduct>(patterns.
getContext());
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Block * getBlock()
Returns the operation block that contains this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void moveOpAfter(Operation *op, Operation *existingOp)
Unlink this operation from its current block and insert it right after existingOp which may be in the...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
LogicalResult shuffleBeforeWriteLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, int64_t nonUnitDimAcc, VectorType accTy)
Operation * traceToVectorWriteLikeUserOperation(Value v)
void populateVectorContractToPackedTypeDotProductPatterns(RewritePatternSet &patterns)
bool isInVnniLayout(Operation *op, llvm::ArrayRef< AffineMap > indexingMaps, std::optional< unsigned > blockingFactor=std::nullopt)
Operation * traceToVectorReadLikeParentOperation(Value v)
LogicalResult shuffleAfterReadLikeOp(PatternRewriter &rewriter, Operation *opA, Operation *opB, vector::ContractionOp contractA, vector::ContractionOp contractB, int64_t nonUnitDimAcc, VectorType accTy)
bool validatePairVectorContract(vector::ContractionOp contractOp, vector::ContractionOp pairContOp, bool rhsHasMultipleNonUnitDims, int64_t nonUnitDimValue)
Include the generated interface declarations.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...