228 if (contractOp.getKind() != vector::CombiningKind::ADD)
230 "Expects add combining kind.");
235 VectorType lhsTy = contractOp.getLhsType();
236 if (!lhsTy.getElementType().isBF16())
238 "Only BF16 lowering is supported.");
241 contractOp.getIndexingMapsArray(),
244 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
248 if (!accTy.getElementType().isF32())
250 contractOp,
"Only F32 acumulation supported for BF16 type.");
254 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
255 [](
int64_t dim) {
return dim != 1; });
256 if (nonUnitDimAcc.size() != 1)
258 contractOp,
"A or B should be a non-unit dim in acc.");
262 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
263 [](
int64_t dim) {
return dim != 1; });
265 VectorType rhsTy = contractOp.getRhsType();
268 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
269 [](
int64_t dim) {
return dim != 1; });
271 if (isVnni && (nonUnitDimLhs.size() - 1) > 0 &&
272 (nonUnitDimRhs.size() - 1) > 0)
274 "Excepts unit dimensions for either "
275 "LHS or RHS shape other than VNNI.");
277 if (isVnni && (nonUnitDimLhs.size() - 1) != 1 &&
278 (nonUnitDimRhs.size() - 1) != 1)
281 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
283 if (!isVnni && nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
285 "Excepts unit dimensions for either "
286 "LHS or RHS shape.");
288 if (!isVnni && nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
291 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
294 unsigned int nonUnitDim = nonUnitDimAcc.front();
295 if (nonUnitDim != 4 && nonUnitDim != 8)
297 contractOp,
"BF16 packed load operation expects non-unit (LHR or "
298 "RHS) dim and acc dim of size 4/8.");
303 contractOp,
"The LHS or RHS is in an invalid format. Either it has "
305 "a non-identity permutation map, a non-zero VNNI offset, "
307 "source, or a non-unit VNNI stride");
311 auto loc = contractOp.getLoc();
320 bool rhsHasMultipleNonUnitDims = nonUnitDimRhs.size() > 0;
322 rhsHasMultipleNonUnitDims = (nonUnitDimRhs.size() - 1) > 0;
327 rhsHasMultipleNonUnitDims ? contractOp.getLhs() : contractOp.getRhs();
329 rhsHasMultipleNonUnitDims ? contractOp.getRhs() : contractOp.getLhs();
332 rhsHasMultipleNonUnitDims ? rhsShape : lhsShape;
339 vector::ContractionOp pairContractOp;
342 while ((nextOp = nextOp->getNextNode())) {
343 auto contOp = dyn_cast<vector::ContractionOp>(nextOp);
349 rhsHasMultipleNonUnitDims,
350 nonUnitDimAcc.front())) {
351 pairContractOp = contOp;
371 if (!accReadOp0 || !accReadOp1)
374 "Operand doesn't have load or transfer_read as its parent op");
376 if (!resultWriteOp0 || !resultWriteOp1)
379 "The use of contract operations are neither vector.store "
380 "or transfer_write or has multiple users");
382 if (contractOp->getBlock() == accReadOp1->
getBlock() &&
383 contractOp->isBeforeInBlock(accReadOp1))
385 contractOp,
"The load/read operation of pair contract operation is "
386 "after the contractOp");
388 if (pairContractOp->getBlock() == resultWriteOp0->
getBlock() &&
391 contractOp,
"The store/write operation of contract operation is "
392 "before the pair contract operation");
398 loc, rewriter, unitSrc, nonUnitDimShape,
true, isVnni);
401 loc, rewriter, nonUnitSrc, nonUnitDimShape,
false, isVnni);
403 auto castAcc = vector::ShapeCastOp::create(
405 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
406 contractOp.getAcc());
408 VectorType::get(nonUnitDimAcc.front(), rewriter.
getF32Type());
426 LogicalResult readShuffle =
428 pairContractOp, nonUnitDim, accTy);
430 if (failed(readShuffle))
432 contractOp,
"Accumulator read is not by transfer_read or load");
436 rewriter, resultWriteOp0, resultWriteOp1, nonUnitDim, accTy);
438 if (failed(writeShuffle))
441 "Write to accumulator is not by transfer_write or store");
444 castAcc = vector::ShapeCastOp::create(
446 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
447 contractOp.getAcc());
449 auto loadBcstBF16ElementToF32 = x86vector::BcstToPackedF32Op::create(
450 rewriter, loc, dstType, unitDimSubview[0]);
451 auto loadEvenIdxElementF32 =
452 x86vector::CvtPackedEvenIndexedToF32Op::create(rewriter, loc, dstType,
453 nonUnitDimSubview[0]);
455 vector::FMAOp::create(rewriter, loc, loadBcstBF16ElementToF32,
456 loadEvenIdxElementF32, castAcc);
458 vector::ShapeCastOp::create(rewriter, loc, accTy, evenIdxFMA);
459 rewriter.
replaceOp(contractOp, castEvenFma);
462 auto pairContOpLoc = pairContractOp.getLoc();
463 VectorType accTyPairCont =
464 dyn_cast<VectorType>(pairContractOp.getAccType());
465 auto castAccPairCont = vector::ShapeCastOp::create(
466 rewriter, pairContOpLoc,
467 VectorType::get(nonUnitDimAcc.front(),
468 accTyPairCont.getElementType()),
469 pairContractOp.getAcc());
471 auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
472 rewriter, pairContOpLoc, dstType, nonUnitDimSubview[0]);
473 auto oddIdxFMA = vector::FMAOp::create(
474 rewriter, pairContOpLoc, loadBcstBF16ElementToF32,
475 loadOddIdxElementF32, castAccPairCont);
476 auto castOddFma = vector::ShapeCastOp::create(rewriter, pairContOpLoc,
477 accTyPairCont, oddIdxFMA);
478 rewriter.
replaceOp(pairContractOp, castOddFma);
484 auto loadBcstOddIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
485 rewriter, loc, dstType, unitDimSubview[0]);
486 auto loadOddIdxElementF32 = x86vector::CvtPackedOddIndexedToF32Op::create(
487 rewriter, loc, dstType, nonUnitDimSubview[0]);
489 vector::FMAOp::create(rewriter, loc, loadBcstOddIdxElementToF32,
490 loadOddIdxElementF32, castAcc);
493 auto loadBcstEvenIdxElementToF32 = x86vector::BcstToPackedF32Op::create(
494 rewriter, loc, dstType, unitDimSubview[1]);
495 auto loadEvenIdxElementF32 = x86vector::CvtPackedEvenIndexedToF32Op::create(
496 rewriter, loc, dstType, nonUnitDimSubview[0]);
498 vector::FMAOp::create(rewriter, loc, loadBcstEvenIdxElementToF32,
499 loadEvenIdxElementF32, oddIdxFMA);
501 auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);