32 #define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
46 template <
typename Op>
47 std::optional<Value> getExtOperand(
Value v) {
49 static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
50 "Must be instantiated with either sign- or zero- extension op");
56 if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
57 auto vTy = cast<VectorType>(v.
getType());
58 if (!vTy.getElementType().isSignlessInteger(8))
67 auto inOp = extOp.getIn();
68 auto inTy = dyn_cast<VectorType>(inOp.getType());
69 if (!inTy || !inTy.getElementType().isSignlessInteger(8))
72 auto outTy = dyn_cast<VectorType>(extOp.getType());
73 if (!outTy || !outTy.getElementType().isSignlessInteger(32))
93 return rewriter.
create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
95 return rewriter.
create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
97 return rewriter.
create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
98 case MMLA::MixedSwapped:
101 return rewriter.
create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
137 class LowerContractionToSVEI8MMPattern
141 LogicalResult matchAndRewrite(vector::ContractionOp op,
145 mlir::VectorType lhsType = op.getLhsType();
146 mlir::VectorType rhsType = op.getRhsType();
149 if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
152 auto M = lhsType.getDimSize(0);
153 auto N = rhsType.getDimSize(0);
154 auto K = rhsType.getDimSize(1);
161 if (lhsType.isScalable() || !rhsType.getScalableDims()[0] ||
162 rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 ||
163 M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 ||
164 !rhsType.getScalableDims()[0])
172 if (op.getIndexingMapsArray()[0] !=
175 op.getIndexingMapsArray()[1] !=
178 op.getIndexingMapsArray()[2] !=
184 auto itTypes = op.getIteratorTypesArray();
185 if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
186 itTypes[1] != vector::IteratorType::parallel ||
187 itTypes[2] != vector::IteratorType::reduction)
189 op,
"iterator types do not correspond to matrix multiplication");
192 if (op.getKind() != vector::CombiningKind::ADD)
194 "combining kind is not an addition");
197 auto outTy = dyn_cast<VectorType>(op.getResultType());
198 if (!outTy || outTy.getElementType() != rewriter.
getI32Type())
200 "output type is not a vector of i32");
207 auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
209 mmlaOp = MMLA::Unsigned;
210 maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
214 op,
"LHS is not a sign- or zero- extended i8");
216 auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
218 if (mmlaOp == MMLA::Unsigned)
219 mmlaOp = MMLA::Mixed;
222 mmlaOp = MMLA::MixedSwapped;
223 maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
227 op,
"RHS is not a sign- or zero- extended i8");
237 for (int64_t i = 0; i < M; i += 2) {
239 auto r0 = rewriter.
create<vector::ExtractOp>(loc, *maybeLhs,
241 auto r1 = rewriter.
create<vector::ExtractOp>(loc, *maybeLhs,
244 auto t = rewriter.
create<vector::ShuffleOp>(
246 llvm::ArrayRef<int64_t>{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
249 auto s = rewriter.
create<vector::ScalableInsertOp>(
250 loc, t, rewriter.
create<ub::PoisonOp>(loc, nxv16i8), 0);
252 auto r = rewriter.
create<arm_sve::DupQLaneOp>(loc, s, 0);
253 lhsTile.push_back(r);
257 auto rhs = rewriter.
create<vector::ShapeCastOp>(
265 for (int64_t
j = 0;
j < N;
j += 2)
267 rewriter.
create<vector::ScalableExtractOp>(loc, nxv16i8, rhs,
j * 8));
281 for (int64_t i = 0; i < M; i += 2) {
283 auto r0 = rewriter.
create<vector::ExtractOp>(loc, op.getAcc(),
285 auto r1 = rewriter.
create<vector::ExtractOp>(loc, op.getAcc(),
288 if (mmlaOp == MMLA::MixedSwapped) {
292 accTileVec = rewriter.
create<vector::InterleaveOp>(loc, r0, r1);
296 auto r0I64 = rewriter.
create<vector::BitCastOp>(loc, accRow64Ty, r0);
297 auto r1I64 = rewriter.
create<vector::BitCastOp>(loc, accRow64Ty, r1);
301 auto intrI64 = rewriter.
create<vector::InterleaveOp>(loc, r0I64, r1I64);
305 rewriter.
create<vector::BitCastOp>(loc, accRowX2Ty, intrI64);
308 for (int64_t
j = 0;
j < N;
j += 2)
309 accTile.push_back(rewriter.
create<vector::ScalableExtractOp>(
310 loc, nxv4i32, accTileVec,
j * 2));
315 for (int64_t i = 0; i < M / 2; ++i)
316 for (int64_t
j = 0;
j < N / 2; ++
j) {
317 Value mmla = createMMLA(rewriter, mmlaOp, loc, nxv4i32,
318 accTile[i * N / 2 +
j], lhsTile[i], rhsTile[
j]);
319 outTile.push_back(mmla);
323 Value result = rewriter.
create<ub::PoisonOp>(loc, op.getResultType());
324 for (int64_t i = 0; i < M / 2; ++i) {
326 Value row = rewriter.
create<ub::PoisonOp>(loc, accRowX2Ty);
327 for (int64_t
j = 0;
j < N / 2; ++
j)
328 row = rewriter.
create<vector::ScalableInsertOp>(
329 loc, outTile[i * N / 2 +
j], row,
j * 4);
336 if (mmlaOp == MMLA::MixedSwapped) {
337 auto tmp = rewriter.
create<vector::DeinterleaveOp>(loc, row);
338 out0 = tmp.getRes1();
339 out1 = tmp.getRes2();
342 auto row64 = rewriter.
create<vector::BitCastOp>(loc, accRowX264Ty, row);
343 auto deintr64 = rewriter.
create<vector::DeinterleaveOp>(loc, row64);
346 out0 = rewriter.
create<vector::BitCastOp>(loc, accRowTy,
348 out1 = rewriter.
create<vector::BitCastOp>(loc, accRowTy,
351 result = rewriter.
create<vector::InsertOp>(loc, out0, result, i * 2);
352 result = rewriter.
create<vector::InsertOp>(loc, out1, result, i * 2 + 1);
365 patterns.add<LowerContractionToSVEI8MMPattern>(context, 2);
static AffineMap getMultiDimMapWithTargets(unsigned numDims, ArrayRef< unsigned > targets, MLIRContext *context)
Returns an affine map with numDims input dimensions and results specified by targets.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateLowerContractionToSVEI8MMPatternPatterns(RewritePatternSet &patterns)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.