MLIR 22.0.0git
VectorContractToFMA.cpp
Go to the documentation of this file.
//===- VectorContractToFMA.cpp --------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
13
15#include "mlir/IR/Dominance.h"
17
18#include "mlir/Pass/Pass.h"
20
21using namespace mlir;
22using namespace mlir::vector;
23using namespace mlir::x86vector;
24
25namespace {
26
27// Implements outer product contraction as a sequence of broadcast and
28// FMA operations.
29//
30// For example - for F32 type:
31// ```
32// vector.contract <1x1xf32>, <1x16xf32> into <1x16xf32>
33// ```
34// to
35// ```
36// vector.broadcast %lhs to <16xf32>
37// vector.fma vector<16xf32>
38// ```
39struct VectorContractToFMA : public OpRewritePattern<vector::ContractionOp> {
40 using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
41
42 LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
43 PatternRewriter &rewriter) const override {
44
45 if (contractOp.getKind() != vector::CombiningKind::ADD)
46 return rewriter.notifyMatchFailure(contractOp,
47 "Expects add combining kind.");
48
49 VectorType lhsTy = contractOp.getLhsType();
50 if (!lhsTy.getElementType().isF32())
51 return rewriter.notifyMatchFailure(contractOp,
52 "Only F32 lowering is supported.");
53
54 ArrayRef<int64_t> lhsShape = lhsTy.getShape();
55 llvm::SmallVector<int64_t> nonUnitDimLhs;
56 llvm::copy_if(lhsShape, std::back_inserter(nonUnitDimLhs),
57 [](int64_t dim) { return dim != 1; });
58
59 VectorType rhsTy = contractOp.getRhsType();
60 ArrayRef<int64_t> rhsShape = rhsTy.getShape();
61 llvm::SmallVector<int64_t> nonUnitDimRhs;
62 llvm::copy_if(rhsShape, std::back_inserter(nonUnitDimRhs),
63 [](int64_t dim) { return dim != 1; });
64
65 if (nonUnitDimLhs.size() > 0 && nonUnitDimRhs.size() > 0)
66 return rewriter.notifyMatchFailure(
67 contractOp, "Excepts unit dimensions for either LHS or RHS shape.");
68
69 if (nonUnitDimLhs.size() != 1 && nonUnitDimRhs.size() != 1)
70 return rewriter.notifyMatchFailure(
71 contractOp,
72 "Excepts a one non-unit A/B dimension for either LHS or RHS shape.");
73
74 VectorType accTy = dyn_cast<VectorType>(contractOp.getAccType());
75 if (!accTy)
76 return rewriter.notifyMatchFailure(contractOp,
77 "Accmulator is not a vector type");
78
79 if (!accTy.getElementType().isF32())
80 return rewriter.notifyMatchFailure(contractOp,
81 "Accmulator should be F32 type.");
82
83 ArrayRef<int64_t> accShape = accTy.getShape();
84 llvm::SmallVector<int64_t> nonUnitDimAcc;
85 llvm::copy_if(accShape, std::back_inserter(nonUnitDimAcc),
86 [](int64_t dim) { return dim != 1; });
87 if (nonUnitDimAcc.size() != 1)
88 return rewriter.notifyMatchFailure(
89 contractOp, "A or B dimension should be non-unit.");
90
91 // Lowers vector.contract into a broadcast+FMA sequence.
92 auto loc = contractOp.getLoc();
93 auto castAcc = vector::ShapeCastOp::create(
94 rewriter, loc,
95 VectorType::get(nonUnitDimAcc.front(), accTy.getElementType()),
96 contractOp.getAcc());
97
98 vector::FMAOp fma;
99
100 // Broadcast the unit-dimension LHS or RHS to match the vector length of the
101 // corresponding non-unit dimension on the other operand. For example,
102 // if LHS has type vector<1x1xf32> and RHS has type vector<1x16xf32>, we
103 // broadcast the LHS to vector<1x16xf32>. In the opposite case (non-unit
104 // dimension on the LHS), we broadcast the RHS instead.
105 if (nonUnitDimRhs.size() > 0) {
106 auto castLhs = vector::ShapeCastOp::create(
107 rewriter, loc, VectorType::get(1, lhsTy.getElementType()),
108 contractOp.getLhs());
109 auto castRhs = vector::ShapeCastOp::create(
110 rewriter, loc,
111 VectorType::get(nonUnitDimRhs.front(), rhsTy.getElementType()),
112 contractOp.getRhs());
113 auto broadcastLhs = vector::BroadcastOp::create(
114 rewriter, loc, castRhs.getResult().getType(), castLhs);
115 fma =
116 vector::FMAOp::create(rewriter, loc, broadcastLhs, castRhs, castAcc);
117 } else {
118 auto castLhs = vector::ShapeCastOp::create(
119 rewriter, loc,
120 VectorType::get(nonUnitDimLhs.front(), lhsTy.getElementType()),
121 contractOp.getLhs());
122 auto castRhs = vector::ShapeCastOp::create(
123 rewriter, loc, VectorType::get(1, rhsTy.getElementType()),
124 contractOp.getRhs());
125 auto broadcastRhs = vector::BroadcastOp::create(
126 rewriter, loc, castLhs.getResult().getType(), castRhs);
127 fma =
128 vector::FMAOp::create(rewriter, loc, castLhs, broadcastRhs, castAcc);
129 }
130
131 auto castFma = vector::ShapeCastOp::create(rewriter, loc, accTy, fma);
132 rewriter.replaceOp(contractOp, castFma);
133
134 return success();
135 }
136};
137
138} // namespace
139
142 patterns.add<VectorContractToFMA>(patterns.getContext());
143}
return success()
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void populateVectorContractToFMAPatterns(RewritePatternSet &patterns)
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...