// Opening reconstructed; the name and LinalgOp parameter are assumptions.
bool isInVnniLayout(linalg::LinalgOp linalgOp,
                    std::optional<unsigned> blockingFactor) {
  // VNNI layout only applies to contractions.
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(linalgOp);
  if (failed(dims))
    return false;
  // Matrix inputs and their types (definitions reconstructed from context).
  Value matA = linalgOp.getDpsInputs()[0];
  Value matB = linalgOp.getDpsInputs()[1];
  auto typeA = dyn_cast<ShapedType>(matA.getType());
  auto typeB = dyn_cast<ShapedType>(matB.getType());
  unsigned rankA = typeA.getRank();
  unsigned rankB = typeB.getRank();
  // VNNI packing adds an extra innermost dimension: operands must be >= 3D.
  if (rankA < 3 || rankB < 3)
    return false;
  // At least two reduction dims are expected: the K dim plus the VNNI dim.
  if (dims->k.size() < 2)
    return false;
  // Loop iterator types; their derivation is elided in the original listing
  // and sketched here with a hypothetical helper.
  FailureOr<SmallVector<mlir::utils::IteratorType>> maybeIters =
      getLoopIteratorTypes(linalgOp); // hypothetical helper, not in the source
  if (failed(maybeIters))
    return false;
  SmallVector<mlir::utils::IteratorType> iteratorTypes = *maybeIters;
  // Indexing maps for A and B (reconstructed from context).
  AffineMap mapA =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(0));
  AffineMap mapB =
      linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
  // A's and B's innermost dim must be the same reduction dim: the VNNI dim.
  auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 1));
  auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 1));
  if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
      iteratorTypes[vnniDimA.getPosition()] !=
          mlir::utils::IteratorType::reduction)
    return false;
  // A's dim[-2] and B's dim[-3] must be the same reduction dim: the K dim.
  auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(rankA - 2));
  auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 3));
  if (!redDimA || !redDimB || redDimA != redDimB ||
      iteratorTypes[redDimA.getPosition()] !=
          mlir::utils::IteratorType::reduction)
    return false;
  // B's second-to-last dim must be a parallel dimension (the N dimension).
  auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(rankB - 2));
  if (!parallelDimB || iteratorTypes[parallelDimB.getPosition()] !=
                           mlir::utils::IteratorType::parallel)
    return false;
  // The VNNI factor (B's innermost dim) must be static and nonzero; the rest
  // of this condition is elided in the source, so evenness is assumed here.
  auto vnniDimSize = typeB.getShape().back();
  if (vnniDimSize == ShapedType::kDynamic || vnniDimSize == 0 ||
      vnniDimSize % 2 != 0) // assumption: VNNI factors are even (2 or 4)
    return false;
  // A's innermost dim must carry the same VNNI factor.
  if (typeA.getShape().back() != vnniDimSize)
    return false;
  if (blockingFactor && vnniDimSize != *blockingFactor)
    return false;
  // The K dims of A and B must agree.
  if (typeA.getShape().end()[-2] != typeB.getShape().end()[-3])
    return false;

  return true;
}
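A minimal usage sketch, assuming the predicate above compiles as isInVnniLayout and is visible from pass code: walk a module and emit a remark on every linalg contraction that is already in VNNI form with a blocking factor of 2 (the bf16 case). The wrapper name reportVnniOps and the remark text are illustrative, not from the source.

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinOps.h"

// Usage sketch: flag every linalg contraction already in VNNI layout.
void reportVnniOps(mlir::ModuleOp module) {
  module->walk([](mlir::linalg::LinalgOp op) {
    if (isInVnniLayout(op, /*blockingFactor=*/2)) // 2 = bf16 VNNI factor
      op->emitRemark("contraction is in VNNI layout");
  });
}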