40 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
42 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
43 PatternRewriter &rewriter)
const override {
45 if (contractOp.getKind() != vector::CombiningKind::ADD)
47 "Expects add combining kind.");
49 VectorType lhsTy = contractOp.getLhsType();
50 if (!lhsTy.getElementType().isF32())
52 "Only F32 lowering is supported.");
54 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
55 llvm::SmallVector<int64_t> nonUnitDimLhs;
56 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
57 [](int64_t dim) {
return dim != 1; });
59 VectorType rhsTy = contractOp.getRhsType();
60 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
61 llvm::SmallVector<int64_t> nonUnitDimRhs;
62 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
63 [](int64_t dim) {
return dim != 1; });
65 if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
67 contractOp,
"Excepts unit dimensions for either LHS or RHS shape.");
69 if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
72 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
74 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
77 "Accmulator is not a vector type");
79 if (!accTy.getElementType().isF32())
81 "Accmulator should be F32 type.");
83 ArrayRef<int64_t> accShape = accTy.getShape();
84 llvm::SmallVector<int64_t> nonUnitDimAcc;
85 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
86 [](int64_t dim) {
return dim != 1; });
87 if (nonUnitDimAcc.size() != 1)
89 contractOp,
"A or B dimension should be non-unit.");
92 auto loc = contractOp.getLoc();
93 auto castAcc = vector::ShapeCastOp::create(
95 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
105 if (nonUnitDimRhs.size() > 0) {
106 auto castLhs = vector::ShapeCastOp::create(
107 rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
108 contractOp.getLhs());
109 auto castRhs = vector::ShapeCastOp::create(
111 VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
112 contractOp.getRhs());
113 auto broadcastLhs = vector::BroadcastOp::create(
114 rewriter, loc, castRhs.getResult().getType(), castLhs);
116 vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
118 auto castLhs = vector::ShapeCastOp::create(
120 VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
121 contractOp.getLhs());
122 auto castRhs = vector::ShapeCastOp::create(
123 rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
124 contractOp.getRhs());
125 auto broadcastRhs = vector::BroadcastOp::create(
126 rewriter, loc, castLhs.getResult().getType(), castRhs);
128 vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc);
131 auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...