MLIR 23.0.0git
LowerVectorContract.cpp
Go to the documentation of this file.
1//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.contract' operation.
11//
12//===----------------------------------------------------------------------===//
13
22#include "mlir/IR/Location.h"
25
26#define DEBUG_TYPE "vector-contract-lowering"
27
28using namespace mlir;
29using namespace mlir::vector;
30
31//===----------------------------------------------------------------------===//
32// Helper functions
33//===----------------------------------------------------------------------===//
34// Helper to find an index in an affine map.
35static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
36 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
37 int64_t idx = map.getDimPosition(i);
38 if (idx == index)
39 return i;
40 }
41 return std::nullopt;
42}
43
44// Helper to construct iterator types with one index removed.
46 int64_t index) {
48 for (const auto &it : llvm::enumerate(iteratorTypes)) {
49 int64_t idx = it.index();
50 if (idx == index)
51 continue;
52 results.push_back(it.value());
53 }
54 return results;
55}
56
57// Helper to construct an affine map with one index removed.
59 PatternRewriter &rewriter) {
60 auto *ctx = rewriter.getContext();
62 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
63 int64_t idx = map.getDimPosition(i);
64 if (idx == index)
65 continue;
66 // Re-insert remaining indices, but renamed when occurring
67 // after the removed index.
68 auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
69 results.push_back(targetExpr);
70 }
71 return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
72}
73
74// Helper method to possibly drop a dimension in a load.
75// TODO
76static Value reshapeLoad(Location loc, Value val, VectorType type,
78 PatternRewriter &rewriter) {
79 if (index == -1)
80 return val;
81
82 // At extraction dimension?
83 if (index == 0)
84 return vector::ExtractOp::create(rewriter, loc, val, pos);
85
86 // Unroll leading dimensions.
87 VectorType vType = VectorType::Builder(type).dropDim(0);
88 VectorType resType = VectorType::Builder(type).dropDim(index);
89 Value result = arith::ConstantOp::create(rewriter, loc, resType,
90 rewriter.getZeroAttr(resType));
91 for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
92 Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
93 Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
94 result = vector::InsertOp::create(rewriter, loc, load, result, d);
95 }
96 return result;
97}
98
99// Helper method to possibly drop a dimension in a store.
100// TODO
102 VectorType type, int64_t index, int64_t pos,
103 PatternRewriter &rewriter) {
104 // Unmodified?
105 if (index == -1)
106 return val;
107 // At insertion dimension?
108 if (index == 0)
109 return vector::InsertOp::create(rewriter, loc, val, result, pos);
110
111 // Unroll leading dimensions.
112 VectorType vType = VectorType::Builder(type).dropDim(0);
113 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
114 Value ext = vector::ExtractOp::create(rewriter, loc, result, d);
115 Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
116 Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
117 result = vector::InsertOp::create(rewriter, loc, sto, result, d);
118 }
119 return result;
120}
121
122/// Helper to create arithmetic operation associated with a kind of contraction.
123static std::optional<Value>
125 vector::CombiningKind kind, PatternRewriter &rewriter,
126 bool isInt, Value mask = Value(),
127 arith::FastMathFlagsAttr fmf = {}) {
128 using vector::CombiningKind;
129 Value mul;
130
131 if (isInt) {
132 if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
133 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
134 // Only valid for floating point types.
135 return std::nullopt;
136 mul = arith::MulIOp::create(rewriter, loc, x, y);
137 } else {
138 // Float case.
139 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
140 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
141 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
142 kind == CombiningKind::XOR)
143 // Only valid for integer types.
144 return std::nullopt;
145 // Special case for fused multiply-add.
146 if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
147 Value fma = vector::FMAOp::create(rewriter, loc, x, y, acc);
148 if (mask)
149 // The fma op doesn't need explicit masking. However, fma ops used in
150 // reductions must preserve previous 'acc' values for masked-out lanes.
151 fma = selectPassthru(rewriter, mask, fma, acc);
152 return fma;
153 }
154 mul = arith::MulFOp::create(rewriter, loc, x, y, fmf);
155 }
156
157 if (!acc)
158 return std::optional<Value>(mul);
159
160 return makeArithReduction(rewriter, loc, kind, mul, acc, fmf, mask);
161}
162
163/// Return the positions of the reductions in the given map.
165 ArrayAttr iteratorTypes) {
166 SmallVector<int64_t> dimsIdx;
167 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
168 if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
169 dimsIdx.push_back(i);
170 }
171 return dimsIdx;
172}
173
174/// Look for a given dimension in an affine map and return its position. Return
175/// std::nullopt if the dimension is not in the map results.
176static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
177 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
178 if (map.getDimPosition(i) == dim)
179 return i;
180 }
181 return std::nullopt;
182}
183
184/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
185/// operands `x` and `y`.
186static Value createAdd(Location loc, Value x, Value y, bool isInt,
187 PatternRewriter &rewriter,
188 arith::FastMathFlagsAttr fmf = {}) {
189 if (isInt)
190 return arith::AddIOp::create(rewriter, loc, x, y);
191 return arith::AddFOp::create(rewriter, loc, x, y, fmf);
192}
193
194/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
195/// operands `x and `y`.
196static Value createMul(Location loc, Value x, Value y, bool isInt,
197 PatternRewriter &rewriter,
198 arith::FastMathFlagsAttr fmf = {}) {
199 if (isInt)
200 return arith::MulIOp::create(rewriter, loc, x, y);
201 return arith::MulFOp::create(rewriter, loc, x, y, fmf);
202}
203
204namespace {
205
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
/// %at = vector.transpose %a, [1, 0]
/// %bRow0 = vector.extract %b[0]
/// %atRow0 = vector.extract %at[0]
/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
/// ...
/// %bRowK = vector.extract %b[K]
/// %atRowK = vector.extract %at[K]
/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when vectorContractLowering is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  /// Extra user-supplied applicability predicate checked before rewriting.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns; this pattern only fires when set
  /// to OuterProduct.
  vector::VectorContractLowering vectorContractLowering;
  /// Applicability constraint installed at construction time.
  FilterConstraintType filter;
};
250
251/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
252/// semantics to an output-size-unrolled sequence:
253/// ```
254/// %out = arith.constant ... : vector<MxNxelt_type>
255/// %bt = vector.transpose %b, [1, 0]
256/// %aRow0 = vector.extract %a[0]
257/// %btRow0 = vector.extract %bt[0]
258/// %c00 = vector.reduction %atRow0, %bRow0
259/// %out00 = vector.insert %c00, %out[0, 0]
260/// ...
261/// %aRowLast = vector.extract %at[M-1]
262/// %btRowLast = vector.extract %b[N-1]
263/// %cLastLast = vector.reduction %atRowLast, %bRowLast
264/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
265/// ```
266///
267/// This only kicks in when VectorTransformsOptions is set to Dot and
268/// the vector.contract op is a row-major matmul or matvec.
269class ContractionOpToDotLowering
270 : public MaskableOpRewritePattern<vector::ContractionOp> {
271public:
272 using MaskableOpRewritePattern::MaskableOpRewritePattern;
273
274 using FilterConstraintType =
275 std::function<LogicalResult(vector::ContractionOp op)>;
276
277 static LogicalResult defaultFilter(vector::ContractionOp op) {
278 return success();
279 }
280
281 ContractionOpToDotLowering(
282 vector::VectorContractLowering vectorContractLowering,
283 MLIRContext *context, PatternBenefit benefit = 1,
284 const FilterConstraintType &constraint = defaultFilter)
285 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
286 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
287
288 FailureOr<Value>
289 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
290 PatternRewriter &rewriter) const override;
291
292private:
293 /// Options to control the vector patterns.
294 vector::VectorContractLowering vectorContractLowering;
295 FilterConstraintType filter;
296};
297
/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either VectorTransformsOptions is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  /// Extra user-supplied applicability predicate checked before rewriting.
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  /// Default filter: accept every contraction.
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(
      vector::VectorContractLowering vectorContractLoweringOption,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLoweringOption(vectorContractLoweringOption),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLoweringOption;
  /// Applicability constraint installed at construction time.
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};
347
/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    // Capture the mask when the contraction is wrapped in a masking op, so
    // the generated outer products can be masked accordingly.
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  /// Transpose `v` with permutation `perm` (defaults to a 2-D swap).
  /// A null value passes through unchanged so optional masks can be
  /// forwarded without a null check at every call site.
  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return vector::TransposeOp::create(rewriter, loc, v, perm);
  }

  /// Promote scalar-or-vector `v` to element type `dstElementType`, using
  /// extf for float destinations and extsi otherwise. No-op when the element
  /// type already matches.
  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return arith::ExtFOp::create(rewriter, loc, promotedType, v);
    return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
  }

  /// Emit `reductionSize` outer products accumulating into `res`. Callers
  /// must have arranged the operands (and `maybeMask`, when present) so that
  /// the reduction dimension is outermost; each iteration extracts one slice
  /// of lhs/rhs/mask and feeds it to vector.outerproduct.
  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking: bail out if the op is masked but the
    // caller did not supply a (possibly null) mask to thread through.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = vector::ExtractOp::create(rewriter, loc, lhs, k);
      Value extractB = vector::ExtractOp::create(rewriter, loc, rhs, k);
      // Operands are promoted to the accumulator's element type before the
      // outer product.
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);

      Operation *outerProdOp = vector::OuterProductOp::create(
          rewriter, loc, res.getType(), extractA, extractB, res, kind);
      // maskOperation is a no-op when extractMask is null.
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  /// Each supported layout permutation is normalized (via transposes of the
  /// operands and the 3-D mask) so the reduction dimension is outermost, then
  /// lowered through outerProd.
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when then pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  // Combining kind of the contraction (add, mul, min, ...).
  vector::CombiningKind kind;
  // Contraction operands; `mask` is null when the op is unmasked.
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};
575
576/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
577/// semantics to a reduction_size-unrolled sequence:
578/// ```
579/// %at = vector.transpose %a, [1, 0]
580/// %bRow0 = vector.extract %b[0]
581/// %atRow0 = vector.extract %at[0]
582/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
583/// ...
584/// %bRowK = vector.extract %b[K]
585/// %atRowK = vector.extract %at[K]
586/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
587/// ```
588///
589/// This only kicks in when vectorContractLowering is set to OuterProduct but
590/// otherwise supports any layout permutation of the matrix-multiply.
591FailureOr<Value>
592ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
593 vector::ContractionOp op, MaskingOpInterface maskOp,
594 PatternRewriter &rewriter) const {
595 if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
596 return failure();
597
598 if (failed(filter(op)))
599 return failure();
600
601 UnrolledOuterProductGenerator e(rewriter, op);
602 FailureOr<Value> matmatRes = e.matmat();
603 if (succeeded(matmatRes)) {
604 return matmatRes;
605 }
606 FailureOr<Value> matvecRes = e.matvec();
607 if (succeeded(matvecRes)) {
608 return matvecRes;
609 }
610
611 FailureOr<Value> tmatvecRes = e.tmatvec();
612 return tmatvecRes;
613}
614
615FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
616 vector::ContractionOp op, MaskingOpInterface maskOp,
617 PatternRewriter &rewriter) const {
618 // TODO: Support vector.mask.
619 if (maskOp)
620 return failure();
621
622 if (failed(filter(op)))
623 return failure();
624
625 if (vectorContractLowering != vector::VectorContractLowering::Dot)
626 return failure();
627
628 auto iteratorTypes = op.getIteratorTypes().getValue();
629 static constexpr std::array<int64_t, 2> perm = {1, 0};
630 Location loc = op.getLoc();
631 Value lhs = op.getLhs(), rhs = op.getRhs();
632
633 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
634 auto infer = [&](MapList m) {
635 return AffineMap::inferFromExprList(m, op.getContext());
636 };
637 AffineExpr m, n, k;
638 bindDims(rewriter.getContext(), m, n, k);
639 SmallVector<AffineMap> maps = op.getIndexingMapsArray();
640 //
641 // In the following we wish to make the reduction dimension innermost so we
642 // can load vectors and just fmul + reduce into a scalar.
643 //
644 if (isParallelIterator(iteratorTypes[0]) &&
645 isParallelIterator(iteratorTypes[1]) &&
646 isReductionIterator(iteratorTypes[2])) {
647 //
648 // Two outer parallel, one inner reduction (matmat flavor).
649 //
650 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
651 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
652 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
653 // No need to permute anything.
654 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
655 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
656 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
657 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
658 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
659 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
660 // This is the classical row-major matmul. Just permute the lhs.
661 Value tmp = lhs;
662 lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
663 rhs = tmp;
664 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
665 std::swap(lhs, rhs);
666 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
667 Value tmp = lhs;
668 lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
669 rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
670 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
671 Value tmp = rhs;
672 rhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
673 lhs = tmp;
674 } else {
675 return failure();
676 }
677 } else if (isParallelIterator(iteratorTypes[0]) &&
678 isReductionIterator(iteratorTypes[1])) {
679 //
680 // One outer parallel, one inner reduction (matvec flavor)
681 //
682 if (maps == infer({{m, n}, {n}, {m}})) {
683 // No need to permute anything.
684 } else if (maps == infer({{n, m}, {n}, {m}})) {
685 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
686 } else if (maps == infer({{n}, {m, n}, {m}})) {
687 std::swap(lhs, rhs);
688 } else if (maps == infer({{n}, {n, m}, {m}})) {
689 std::swap(lhs, rhs);
690 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
691 } else {
692 return failure();
693 }
694 } else {
695 return failure();
696 }
697
698 VectorType dstType = cast<VectorType>(op.getResultType());
699 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
700 "Expected dst type of rank 1 or 2");
701
702 unsigned rank = dstType.getRank();
703 unsigned dstRows = dstType.getShape()[0];
704 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
705
706 // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
707 Value res = arith::ConstantOp::create(rewriter, loc, dstType,
708 rewriter.getZeroAttr(dstType));
709 bool isInt = isa<IntegerType>(dstType.getElementType());
710 arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
711 llvm::SmallVector<Value> extractedCols;
712 extractedCols.reserve(dstColumns);
713 for (unsigned r = 0; r < dstRows; ++r) {
714 Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(), lhs, r);
715 for (unsigned c = 0; c < dstColumns; ++c) {
716 // Extract each respective row and column of the LHS and RHS once to
717 // avoid having duplicate SSA values pointing to the same rows/columns.
718 if (r == 0) {
719 Value colRhs =
720 rank == 1
721 ? rhs
722 : vector::ExtractOp::create(rewriter, op.getLoc(), rhs, c);
723 extractedCols.push_back(colRhs);
724 }
725 Value extractedColRhs = extractedCols[c];
726 Value product =
727 createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter, fmf);
728 Value sum = vector::ReductionOp::create(rewriter, op.getLoc(),
729 vector::CombiningKind::ADD,
730 product, op.getFastmath());
731
734 res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
735 }
736 }
737 if (auto acc = op.getAcc())
738 res = createAdd(op.getLoc(), res, acc, isInt, rewriter, fmf);
739 return res;
740}
741
742/// Lower vector.contract with all size one reduction dimensions to
743/// elementwise ops when possible.
745 : public MaskableOpRewritePattern<vector::ContractionOp> {
746 using MaskableOpRewritePattern::MaskableOpRewritePattern;
748 std::function<LogicalResult(vector::ContractionOp op)>;
749 static LogicalResult defaultFilter(vector::ContractionOp op) {
750 return success();
751 }
753 vector::VectorContractLowering vectorContractLowering,
754 MLIRContext *context, PatternBenefit benefit = 1,
755 const FilterConstraintType &constraint = defaultFilter)
756 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
757 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
758
759 FailureOr<Value>
760 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
761 MaskingOpInterface maskOp,
762 PatternRewriter &rewriter) const override {
763 // TODO: Support vector.mask.
764 if (maskOp)
765 return failure();
766
767 if (failed(filter(contractOp)))
768 return failure();
769
770 if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
771 return failure();
772
773 ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
774 ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
775 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
776 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
777 SmallVector<int64_t> lhsReductionDims =
778 getReductionIndex(lhsMap, contractOp.getIteratorTypes());
779 SmallVector<int64_t> rhsReductionDims =
780 getReductionIndex(rhsMap, contractOp.getIteratorTypes());
781 // All the reduction dimensions must be a size 1.
782 for (int64_t dim : lhsReductionDims) {
783 if (lhsShape[dim] != 1)
784 return failure();
785 }
786 for (int64_t dim : rhsReductionDims) {
787 if (rhsShape[dim] != 1)
788 return failure();
789 }
790 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
791 unsigned numParallelDims = accMap.getNumResults();
792 unsigned numLhsDimToBroadcast =
793 numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
794 unsigned numRhsDimToBroadcast =
795 numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
796 SmallVector<int64_t> lhsDims;
797 SmallVector<int64_t> lhsTranspose;
798 SmallVector<int64_t> rhsDims;
799 SmallVector<int64_t> rhsTranspose;
800 for (int64_t dim : lhsReductionDims)
801 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
802 for (int64_t dim : rhsReductionDims)
803 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
804 // Loop through the parallel dimensions to calculate the dimensions to
805 // broadcast and to permute in order to extract only parallel dimensions.
806 for (unsigned i = 0; i < numParallelDims; i++) {
807 std::optional<unsigned> lhsDim =
808 getDimPosition(lhsMap, accMap.getDimPosition(i));
809 if (lhsDim) {
810 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
811 } else {
812 // If the parallel dimension doesn't exist we will have to broadcast it.
813 lhsDims.push_back(
814 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
815 lhsTranspose.push_back(lhsDims.size() - 1);
816 }
817 std::optional<unsigned> rhsDim =
818 getDimPosition(rhsMap, accMap.getDimPosition(i));
819 if (rhsDim) {
820 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
821 } else {
822 // If the parallel dimension doesn't exist we will have to broadcast it.
823 rhsDims.push_back(
824 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
825 rhsTranspose.push_back(rhsDims.size() - 1);
826 }
827 }
828 Value newLhs = contractOp.getLhs();
829 Value newRhs = contractOp.getRhs();
830 Location loc = contractOp.getLoc();
831 if (!lhsDims.empty()) {
832 lhsDims.append(lhsShape.begin(), lhsShape.end());
833 auto expandedType =
834 VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
835 newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
836 }
837 if (!rhsDims.empty()) {
838 rhsDims.append(rhsShape.begin(), rhsShape.end());
839 auto expandedType =
840 VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
841 newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
842 }
843 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
844 newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
845 newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
846 SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
847 SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
848 newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
849 newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
850 std::optional<Value> result =
851 createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
852 contractOp.getKind(), rewriter, isInt,
853 /*mask=*/Value(), contractOp.getFastmathAttr());
854 if (result)
855 return *result;
856
857 return failure();
858 }
859
860private:
861 /// Options to control the vector patterns.
862 vector::VectorContractLowering vectorContractLowering;
864};
865
866/// Progressive lowering of ContractionOp.
867/// One:
868/// %x = vector.contract with at least one free/batch dimension
869/// is replaced by:
870/// %a = vector.contract with one less free/batch dimension
871/// %b = vector.contract with one less free/batch dimension
872/// ..
873/// %x = combine %a %b ..
874/// until a pure contraction is reached (no free/batch dimensions),
875/// which is replaced by a dot-product.
876///
877/// This only kicks in when either vectorContractLoweringOption is set
878/// to DOT or when other contraction patterns fail.
879//
880// TODO: break down into transpose/reshape/cast ops
881// when they become available to avoid code dup
882// TODO: investigate lowering order impact on performance
883FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
884 vector::ContractionOp op, MaskingOpInterface maskOp,
885 PatternRewriter &rewriter) const {
886 if (failed(filter(op)))
887 return failure();
888
889 // TODO: support mixed mode contract lowering.
890 if (op.getLhsType().getElementType() !=
891 getElementTypeOrSelf(op.getAccType()) ||
892 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
893 return failure();
894
895 // TODO: the code below assumes the default contraction, make sure it supports
896 // other kinds before enabling this lowering.
897 if (op.getKind() != vector::CombiningKind::ADD) {
898 return rewriter.notifyMatchFailure(
899 op, "contractions other than 'add' not supported");
900 }
901
902 // TODO: implement benefits, cost models.
903 MLIRContext *ctx = op.getContext();
904
905 ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
906 FailureOr<Value> newVal1 =
907 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
908 if (!failed(newVal1))
909 return newVal1;
910
911 ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
912 FailureOr<Value> newVal2 =
913 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
914 if (!failed(newVal2))
915 return newVal2;
916
917 ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
918 FailureOr<Value> newVal4 =
919 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
920 if (!failed(newVal4))
921 return newVal4;
922
923 // Vector mask setup.
924
925 Value mask;
926 if (maskOp)
927 mask = maskOp.getMask();
928 // Find first batch dimension in LHS/RHS, and lower when found.
929 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
930 if (!batchDimMap.empty()) {
931 int64_t lhsIndex = batchDimMap[0].first;
932 int64_t rhsIndex = batchDimMap[0].second;
933 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
934 if (failed(newOp))
935 return failure();
936 return newOp;
937 }
938
939 // Collect contracting dimensions.
940 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
941 op.getContractingDimMap();
942 DenseSet<int64_t> lhsContractingDimSet;
943 DenseSet<int64_t> rhsContractingDimSet;
944 for (auto &dimPair : contractingDimMap) {
945 lhsContractingDimSet.insert(dimPair.first);
946 rhsContractingDimSet.insert(dimPair.second);
947 }
948
949 // Find first free dimension in LHS, and lower when found.
950 VectorType lhsType = op.getLhsType();
951 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
952 if (lhsContractingDimSet.count(lhsIndex) == 0) {
953 auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
954 if (failed(newOp))
955 return failure();
956 return newOp;
957 }
958 }
959
960 // Find first free dimension in RHS, and lower when found.
961 VectorType rhsType = op.getRhsType();
962 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
963 if (rhsContractingDimSet.count(rhsIndex) == 0) {
964 auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
965 if (failed(newOp))
966 return failure();
967 return newOp;
968 }
969 }
970
971 // Lower the first remaining reduction dimension.
972 if (!contractingDimMap.empty()) {
973 auto newOp = lowerReduction(rewriter, op, mask);
974 if (failed(newOp))
975 return failure();
976 return newOp;
977 }
978
979 return failure();
980}
981
// Lower one parallel dimension of a vector.contract by unrolling it into a
// series of lower-dimensional vector.contract ops, one per slice along that
// dimension, whose results are re-assembled into the full result vector.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// `lhsIndex`/`rhsIndex` are the positions of the dimension on the LHS/RHS
// operand (-1 when the dimension does not appear on that side); at least one
// must be nonnegative.
// TODO: consider reusing existing contract unrolling
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    // When the dimension appears on both operands, both positions must map to
    // the same iterator; otherwise the indices are inconsistent.
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    // Unrolling needs a static trip count, which a scalable dim cannot give.
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewriteMaskableOp is currently
  // calling lowerParallel also for the case of unit-size reduction dims
  // appearing only on one of LHS or RHS, not both. At the moment, such cases
  // are created by CastAwayContractionLeadingOneDim, so we need to either
  // support that or modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  // The unrolled iterator is dropped from every map and from the iterator
  // types of the lower-dimensional contractions.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // Start from an all-zero result and insert each slice's contraction result.
  Location loc = op.getLoc();
  Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                           rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    // Extract slice `d` of each operand (and accumulator/mask) along the
    // unrolled dimension; reshapeLoad tolerates index -1 (no-op).
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract =
        vector::ContractionOp::create(rewriter, loc, lhs, rhs, acc, lowAffine,
                                      lowIter, op.getKind(), op.getFastmath());
    // Re-wrap in a vector.mask op when a (sliced) mask is present.
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}
1070
1071// Lower one reduction dimension.
1072FailureOr<Value> ContractionOpLowering::lowerReduction(
1073 PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1074 auto loc = op.getLoc();
1075 VectorType lhsType = op.getLhsType();
1076 VectorType rhsType = op.getRhsType();
1077 Type resType = op.getResultType();
1078 if (isa<VectorType>(resType))
1079 return rewriter.notifyMatchFailure(op,
1080 "did not expect a VectorType result");
1081 bool isInt = isa<IntegerType>(resType);
1082 // Use iterator index 0.
1083 int64_t iterIndex = 0;
1084 SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1085 std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1086 std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1087 if (!lookupLhs.has_value())
1088 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1089 diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1090 });
1091 if (!lookupRhs.has_value())
1092 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1093 diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1094 });
1095 int64_t lhsIndex = *lookupLhs;
1096 int64_t rhsIndex = *lookupRhs;
1097 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1098 if (dimSize != rhsType.getDimSize(rhsIndex))
1099 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1100 diag << "expect LHS dimension " << lhsIndex
1101 << " to have the same size as RHS dimension " << rhsIndex;
1102 });
1103 // Base case.
1104 if (lhsType.getRank() == 1) {
1105 if (rhsType.getRank() != 1)
1106 return rewriter.notifyMatchFailure(
1107 op, "When LHS has rank 1, expected also RHS to have rank 1");
1108 arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
1109 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter, fmf);
1110 auto kind = vector::CombiningKind::ADD;
1111
1112 Value acc = op.getAcc();
1113 Operation *reductionOp =
1114 acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc,
1115 op.getFastmath())
1116 : vector::ReductionOp::create(rewriter, loc, kind, m,
1117 op.getFastmath());
1118 return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1119 }
1120 // Construct new iterator types and affine map array attribute.
1121 std::array<AffineMap, 3> lowIndexingMaps = {
1122 adjustMap(iMap[0], iterIndex, rewriter),
1123 adjustMap(iMap[1], iterIndex, rewriter),
1124 adjustMap(iMap[2], iterIndex, rewriter)};
1125 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1126 auto lowIter =
1127 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1128 // Unroll into a series of lower dimensional vector.contract ops.
1129 // By feeding the initial accumulator into the first contraction,
1130 // and the result of each contraction into the next, eventually
1131 // the sum of all reductions is computed.
1132 Value result = op.getAcc();
1133 for (int64_t d = 0; d < dimSize; ++d) {
1134 auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
1135 auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
1136 Value newMask;
1137 if (mask)
1138 newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
1139 iterIndex, d, rewriter);
1140
1141 Operation *newContract = vector::ContractionOp::create(
1142 rewriter, loc, lhs, rhs, result, lowAffine, lowIter, op.getKind(),
1143 op.getFastmath());
1144 result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1145 }
1146 return result;
1147}
1148
1149/// Progressive lowering of OuterProductOp.
1150/// One:
1151/// %x = vector.outerproduct %lhs, %rhs, %acc
1152/// is replaced by:
1153/// %z = zero-result
1154/// %0 = vector.extract %lhs[0]
1155/// %1 = vector.broadcast %0
1156/// %2 = vector.extract %acc[0]
1157/// %3 = vector.fma %1, %rhs, %2
1158/// %4 = vector.insert %3, %z[0]
1159/// ..
1160/// %x = vector.insert %.., %..[N-1]
1161///
1162class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1163public:
1164 using Base::Base;
1165
1166 LogicalResult matchAndRewrite(vector::OuterProductOp op,
1167 PatternRewriter &rewriter) const override {
1168 VectorType resType = op.getResultVectorType();
1169 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1170 return failure();
1171
1172 auto loc = op.getLoc();
1173
1174 VectorType lhsType = op.getOperandVectorTypeLHS();
1175 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1176 Type eltType = resType.getElementType();
1177 bool isInt = isa<IntegerType, IndexType>(eltType);
1178 Value acc = op.getAcc();
1179 vector::CombiningKind kind = op.getKind();
1180
1181 // Vector mask setup.
1182 OpBuilder::InsertionGuard guard(rewriter);
1183 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1184 Operation *rootOp;
1185 Value mask;
1186 if (maskableOp.isMasked()) {
1187 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
1188 rootOp = maskableOp.getMaskingOp();
1189 mask = maskableOp.getMaskingOp().getMask();
1190 } else {
1191 rootOp = op;
1192 }
1193
1194 if (!rhsType) {
1195 // Special case: AXPY operation.
1196 Value b =
1197 vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
1198 std::optional<Value> mult = createContractArithOp(
1199 loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1200 if (!mult.has_value())
1201 return failure();
1202 rewriter.replaceOp(rootOp, *mult);
1203 return success();
1204 }
1205
1206 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1207 rewriter.getZeroAttr(resType));
1208 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1209 Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
1210 Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
1211 Value r = nullptr;
1212 if (acc)
1213 r = vector::ExtractOp::create(rewriter, loc, acc, d);
1214 Value extrMask;
1215 if (mask)
1216 extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);
1217
1218 std::optional<Value> m = createContractArithOp(
1219 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1220 if (!m.has_value())
1221 return failure();
1222 result = vector::InsertOp::create(rewriter, loc, *m, result, d);
1223 }
1224
1225 rewriter.replaceOp(rootOp, result);
1226 return success();
1227 }
1228};
1229
1230} // namespace
1231
1233 RewritePatternSet &patterns,
1234 VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1235 bool disableOuterProductLowering) {
1236 if (!disableOuterProductLowering)
1237 patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1238 patterns.add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
1239 vectorContractLoweringOption, patterns.getContext(), benefit);
1240}
1241
return success()
static int64_t product(ArrayRef< int64_t > vals)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
auto load
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter, arith::FastMathFlagsAttr fmf={})
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static Value reshapeStore(Location loc, Value val, Value result, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value(), arith::FastMathFlagsAttr fmf={})
Helper to create arithmetic operation associated with a kind of contraction.
static Value reshapeLoad(Location loc, Value val, VectorType type, int64_t index, int64_t pos, PatternRewriter &rewriter)
static std::string diag(const llvm::Value &value)
#define mul(a, b)
Progressive lowering of OuterProductOp.
LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:270
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:322
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:433
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
bool iters(ArrayRef< IteratorType > its)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:748
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition VectorOps.h:156
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Lower vector.contract with all size one reduction dimensions to elementwise ops when possible.
FailureOr< Value > matchAndRewriteMaskableOp(vector::ContractionOp contractOp, MaskingOpInterface maskOp, PatternRewriter &rewriter) const override
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractOpToElementwise(vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit=1, const FilterConstraintType &constraint=defaultFilter)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0