MLIR 23.0.0git
LowerVectorContract.cpp
Go to the documentation of this file.
1//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.contract' operation.
11//
12//===----------------------------------------------------------------------===//
13
22#include "mlir/IR/Location.h"
25
26#define DEBUG_TYPE "vector-contract-lowering"
27
28using namespace mlir;
29using namespace mlir::vector;
30
31//===----------------------------------------------------------------------===//
32// Helper functions
33//===----------------------------------------------------------------------===//
34// Helper to find an index in an affine map.
35static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
36 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
37 int64_t idx = map.getDimPosition(i);
38 if (idx == index)
39 return i;
40 }
41 return std::nullopt;
42}
43
44// Helper to construct iterator types with one index removed.
// Every entry of `iteratorTypes` except the one at position `index` is copied,
// in order, into `results`, which is returned. If `index` matches no position,
// nothing is dropped.
46 int64_t index) {
48 for (const auto &it : llvm::enumerate(iteratorTypes)) {
49 int64_t idx = it.index();
// Skip only the entry being removed; all others keep their relative order.
50 if (idx == index)
51 continue;
52 results.push_back(it.value());
53 }
54 return results;
55}
56
57// Helper to construct an affine map with one index removed.
// Every result of `map` that refers to dimension `index` is dropped. Dimensions
// above `index` are renumbered down by one. The returned map has one dimension
// fewer than `map` and no symbols.
// Example: (d0, d1, d2) -> (d0, d2) with index = 1 becomes (d0, d1) -> (d0, d1).
// NOTE(review): results are read with getDimPosition, so they are presumably
// required to be plain dimension expressions — confirm for all callers.
59 PatternRewriter &rewriter) {
60 auto *ctx = rewriter.getContext();
62 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
63 int64_t idx = map.getDimPosition(i);
64 if (idx == index)
65 continue;
66 // Re-insert remaining indices, but renamed when occurring
67 // after the removed index.
68 auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
69 results.push_back(targetExpr);
70 }
71 return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
72}
73
74/// Returns `val` with the dimension at position `index` dropped by indexing
75/// that dimension with `pos`.
76///
77/// If `index == -1`, returns `val` unchanged. If `index == 0`, the result is
78/// a single `vector.extract %val[pos]`.
79///
80/// Example (`index == 0`): extract the sub-vector at `pos` along the leading
81/// dimension.
82/// // val : vector<4x8xf32>, pos = 2
83/// %res = vector.extract %val[2] : vector<8xf32> from vector<4x8xf32>
84///
85/// For `index > 0`, recursively applies the same drop to each sub-vector of
86/// the leading dimension and reassembles the result.
///
/// Example (`index == 1`): for `val : vector<4x8xf32>` and `pos = 2`, the
/// result is a `vector<4xf32>` whose element `d` is `val[d][2]`.
88 PatternRewriter &rewriter) {
89 if (index == -1)
90 return val;
91
92 // At extraction dimension?
93 if (index == 0)
94 return vector::ExtractOp::create(rewriter, loc, val, pos);
95
96 // Unroll leading dimensions.
// Because `index > 0`, dim 0 of `resType` is dim 0 of `type`. The zero
// constant is only a placeholder: the loop overwrites every leading slice.
97 VectorType type = cast<VectorType>(val.getType());
98 VectorType resType = VectorType::Builder(type).dropDim(index);
99 Value result = arith::ConstantOp::create(rewriter, loc, resType,
100 rewriter.getZeroAttr(resType));
101 for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
102 Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
// Recurse one level down: the target dimension is now at `index - 1`.
103 Value load = reshapeLoad(loc, ext, index - 1, pos, rewriter);
104 result = vector::InsertOp::create(rewriter, loc, load, result, d);
105 }
106 return result;
107}
108
109/// Inserts `val` into `result` at position `pos` along dimension `index`.
110///
111/// This is the inverse of `reshapeLoad`. If `index == -1`, returns `val`. If
112/// `index == 0`, the result is a single `vector.insert %val, %result [pos]`.
113///
114/// Example (`index == 0`): insert `val` at `pos` along the leading dimension.
115/// // val : vector<4xf32>, acc : vector<2x4xf32>, pos = 1
116/// %res = vector.insert %val, %acc [1] : vector<4xf32> into vector<2x4xf32>
117///
118/// For `index > 0`, recursively applies the same insertion to each sub-vector
119/// of the leading dimension and reassembles the result.
///
/// Example (`index == 1`): for `val : vector<2xf32>`, `result :
/// vector<2x4xf32>` and `pos = 1`, element `[d][1]` of the returned value is
/// `val[d]`; all other elements come from `result`.
///
/// Note: when `index == -1` the `result` operand is ignored entirely.
121 int64_t pos, PatternRewriter &rewriter) {
122 // Unmodified?
123 if (index == -1)
124 return val;
125 // At insertion dimension?
126 if (index == 0)
127 return vector::InsertOp::create(rewriter, loc, val, result, pos);
128
129 // Unroll leading dimensions.
130 VectorType type = cast<VectorType>(result.getType());
131 for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
// `result` is rebuilt slice by slice. Slice `d` has not been rewritten yet, so
// `ext` still holds its original contents.
132 Value ext = vector::ExtractOp::create(rewriter, loc, result, d);
133 Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
134 Value sto = reshapeStore(loc, ins, ext, index - 1, pos, rewriter);
135 result = vector::InsertOp::create(rewriter, loc, sto, result, d);
136 }
137 return result;
138}
139
140/// Helper to create arithmetic operation associated with a kind of contraction.
///
/// Builds `x * y` and, when `acc` is non-null, combines the product into `acc`
/// according to `kind`.
///
/// Returns std::nullopt, before creating any op, when `kind` is invalid for the
/// element class selected by `isInt`:
///  - float-only min/max kinds with integers;
///  - bitwise and integer min/max kinds with floats.
///
/// Otherwise:
///  - Floats with a vector-typed `acc` and `kind == ADD` emit a single
///    `vector.fma`. If `mask` is set, masked-out lanes keep their previous
///    `acc` value.
///  - Without `acc`, the bare product is returned.
///  - In all remaining cases, `mask` and `fmf` are forwarded to
///    `makeArithReduction`.
///
/// NOTE(review): `fmf` is not passed to the `vector.fma` path — confirm intended.
141static std::optional<Value>
143 vector::CombiningKind kind, PatternRewriter &rewriter,
144 bool isInt, Value mask = Value(),
145 arith::FastMathFlagsAttr fmf = {}) {
146 using vector::CombiningKind;
147 Value mul;
148
149 if (isInt) {
150 if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
151 kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
152 // Only valid for floating point types.
153 return std::nullopt;
154 mul = arith::MulIOp::create(rewriter, loc, x, y);
155 } else {
156 // Float case.
157 if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
158 kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
159 kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
160 kind == CombiningKind::XOR)
161 // Only valid for integer types.
162 return std::nullopt;
163 // Special case for fused multiply-add.
164 if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
165 Value fma = vector::FMAOp::create(rewriter, loc, x, y, acc);
166 if (mask)
167 // The fma op doesn't need explicit masking. However, fma ops used in
168 // reductions must preserve previous 'acc' values for masked-out lanes.
169 fma = selectPassthru(rewriter, mask, fma, acc);
170 return fma;
171 }
172 mul = arith::MulFOp::create(rewriter, loc, x, y, fmf);
173 }
174
// No accumulator: the product itself is the result.
175 if (!acc)
176 return std::optional<Value>(mul);
177
178 return makeArithReduction(rewriter, loc, kind, mul, acc, fmf, mask);
179}
180
181/// Return the positions of the reductions in the given map.
///
/// The returned values are indices into the results of `map`, not dimension
/// ids. They are listed in increasing order, one for every result whose
/// dimension is a reduction iterator in `iteratorTypes`.
183 ArrayAttr iteratorTypes) {
184 SmallVector<int64_t> dimsIdx;
185 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
// Map result i -> iteration dimension -> iterator kind.
186 if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
187 dimsIdx.push_back(i);
188 }
189 return dimsIdx;
190}
191
192/// Look for a given dimension in an affine map and return its position. Return
193/// std::nullopt if the dimension is not in the map results.
194static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
195 for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
196 if (map.getDimPosition(i) == dim)
197 return i;
198 }
199 return std::nullopt;
200}
201
202/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
203/// operands `x` and `y`.
204static Value createAdd(Location loc, Value x, Value y, bool isInt,
205 PatternRewriter &rewriter,
206 arith::FastMathFlagsAttr fmf = {}) {
207 if (isInt)
208 return arith::AddIOp::create(rewriter, loc, x, y);
209 return arith::AddFOp::create(rewriter, loc, x, y, fmf);
210}
211
212/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using
213/// operands `x and `y`.
214static Value createMul(Location loc, Value x, Value y, bool isInt,
215 PatternRewriter &rewriter,
216 arith::FastMathFlagsAttr fmf = {}) {
217 if (isInt)
218 return arith::MulIOp::create(rewriter, loc, x, y);
219 return arith::MulFOp::create(rewriter, loc, x, y, fmf);
220}
221
222namespace {
223
224/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
225/// semantics to a reduction_size-unrolled sequence:
226/// ```
227/// %at = vector.transpose %a, [1, 0]
228/// %bRow0 = vector.extract %b[0]
229/// %atRow0 = vector.extract %at[0]
230/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
231/// ...
232/// %bRowK = vector.extract %b[K]
233/// %atRowK = vector.extract %at[K]
234/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
235/// ```
236///
237/// This only kicks in when vectorContractLowering is set to OuterProduct and
238/// the vector.contract op is a row-major matrix multiply.
239class ContractionOpToOuterProductOpLowering
240 : public MaskableOpRewritePattern<vector::ContractionOp> {
241public:
242 using MaskableOpRewritePattern::MaskableOpRewritePattern;
243
244 using FilterConstraintType =
245 std::function<LogicalResult(vector::ContractionOp op)>;
246
247 static LogicalResult defaultFilter(vector::ContractionOp op) {
248 return success();
249 }
250
251 ContractionOpToOuterProductOpLowering(
252 vector::VectorContractLowering vectorContractLowering,
253 MLIRContext *context, PatternBenefit benefit = 1,
254 FilterConstraintType constraint = defaultFilter)
255 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
256 vectorContractLowering(vectorContractLowering),
257 filter(std::move(constraint)) {}
258
259 FailureOr<Value>
260 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
261 PatternRewriter &rewriter) const override;
262
263private:
264 /// Options to control the vector patterns.
265 vector::VectorContractLowering vectorContractLowering;
266 FilterConstraintType filter;
267};
268
269/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
270/// semantics to an output-size-unrolled sequence:
271/// ```
272/// %out = arith.constant ... : vector<MxNxelt_type>
273/// %bt = vector.transpose %b, [1, 0]
274/// %aRow0 = vector.extract %a[0]
275/// %btRow0 = vector.extract %bt[0]
276/// %c00 = vector.reduction %atRow0, %bRow0
277/// %out00 = vector.insert %c00, %out[0, 0]
278/// ...
279/// %aRowLast = vector.extract %at[M-1]
280/// %btRowLast = vector.extract %b[N-1]
281/// %cLastLast = vector.reduction %atRowLast, %bRowLast
282/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1]
283/// ```
284///
285/// This only kicks in when VectorTransformsOptions is set to Dot and
286/// the vector.contract op is a row-major matmul or matvec.
287class ContractionOpToDotLowering
288 : public MaskableOpRewritePattern<vector::ContractionOp> {
289public:
290 using MaskableOpRewritePattern::MaskableOpRewritePattern;
291
292 using FilterConstraintType =
293 std::function<LogicalResult(vector::ContractionOp op)>;
294
295 static LogicalResult defaultFilter(vector::ContractionOp op) {
296 return success();
297 }
298
299 ContractionOpToDotLowering(
300 vector::VectorContractLowering vectorContractLowering,
301 MLIRContext *context, PatternBenefit benefit = 1,
302 const FilterConstraintType &constraint = defaultFilter)
303 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
304 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
305
306 FailureOr<Value>
307 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
308 PatternRewriter &rewriter) const override;
309
310private:
311 /// Options to control the vector patterns.
312 vector::VectorContractLowering vectorContractLowering;
313 FilterConstraintType filter;
314};
315
316/// Progressive lowering of ContractionOp.
317///
318/// One:
319/// %x = vector.contract with at least one free/batch dimension
320/// is replaced by:
321/// %a = vector.contract with one less free/batch dimension
322/// %b = vector.contract with one less free/batch dimension
323/// ..
324/// %x = combine %a %b ..
325/// until a pure contraction is reached (no free/batch dimensions),
326/// which is replaced by a dot-product.
327///
328/// This only kicks in when either VectorTransformsOptions is set
329/// to Dot or when other contraction patterns fail.
330class ContractionOpLowering
331 : public MaskableOpRewritePattern<vector::ContractionOp> {
332public:
333 using MaskableOpRewritePattern::MaskableOpRewritePattern;
334 using FilterConstraintType =
335 std::function<LogicalResult(vector::ContractionOp op)>;
336
337 static LogicalResult defaultFilter(vector::ContractionOp op) {
338 return success();
339 }
340
341 ContractionOpLowering(
342 vector::VectorContractLowering vectorContractLoweringOption,
343 MLIRContext *context, PatternBenefit benefit = 1,
344 FilterConstraintType constraint = defaultFilter)
345 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
346 vectorContractLoweringOption(vectorContractLoweringOption),
347 filter(std::move(constraint)) {}
348
349 FailureOr<Value>
350 matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
351 PatternRewriter &rewriter) const override;
352
353private:
354 /// Options to control the vector patterns.
355 vector::VectorContractLowering vectorContractLoweringOption;
356 FilterConstraintType filter;
357 // Lower one parallel dimension.
358 FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
359 vector::ContractionOp op, int64_t lhsIndex,
360 int64_t rhsIndex, Value mask) const;
361 // Lower one reduction dimension.
362 FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
363 vector::ContractionOp op, Value mask) const;
364};
365
366/// Generate a vector implementation for matmat, matvec and tmatvec.
367/// This unrolls outer-products along the reduction dimension.
368struct UnrolledOuterProductGenerator
369 : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
370 UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
371 : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
372 kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
373 res(op.getAcc()), lhsType(op.getLhsType()) {
374 auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
375 if (maskableOp.isMasked())
376 mask = maskableOp.getMaskingOp().getMask();
377 }
378
379 Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
380 if (!v)
381 return v;
382 return vector::TransposeOp::create(rewriter, loc, v, perm);
383 }
384
385 Value promote(Value v, Type dstElementType) {
386 Type elementType = v.getType();
387 auto vecType = dyn_cast<VectorType>(elementType);
388 if (vecType)
389 elementType = vecType.getElementType();
390 if (elementType == dstElementType)
391 return v;
392 Type promotedType = dstElementType;
393 if (vecType)
394 promotedType = vecType.clone(promotedType);
395 if (isa<FloatType>(dstElementType))
396 return arith::ExtFOp::create(rewriter, loc, promotedType, v);
397 return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
398 }
399
400 FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
401 VectorType lhsType, int reductionSize,
402 std::optional<Value> maybeMask = std::nullopt) {
403 // Incremental support for masking.
404 if (mask && !maybeMask.has_value())
405 return failure();
406
407 Type resElementType = cast<VectorType>(res.getType()).getElementType();
408 for (int64_t k = 0; k < reductionSize; ++k) {
409 Value extractA = vector::ExtractOp::create(rewriter, loc, lhs, k);
410 Value extractB = vector::ExtractOp::create(rewriter, loc, rhs, k);
411 extractA = promote(extractA, resElementType);
412 extractB = promote(extractB, resElementType);
413 Value extractMask;
414 if (maybeMask.has_value() && maybeMask.value())
415 extractMask =
416 vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);
417
418 Operation *outerProdOp = vector::OuterProductOp::create(
419 rewriter, loc, res.getType(), extractA, extractB, res, kind);
420 res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
421 }
422 return res;
423 }
424
425 /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
426 /// dimension `reductionDim`. If the dimension is a scalable dimension,
427 /// returns "nullopt".
428 std::optional<int64_t> getReductionSize(VectorType vecType,
429 int64_t reductionDim) {
430 // Cannot unroll scalable dimension.
431 if (vecType.getScalableDims()[reductionDim])
432 return std::nullopt;
433 int64_t reductionSize = vecType.getDimSize(reductionDim);
434 assert(reductionSize > 0 &&
435 "Reduction dim must be a known static size to allow unrolling");
436 return reductionSize;
437 }
438
439 /// Two outer parallel, one inner reduction (matmat flavor).
440 FailureOr<Value> matmat() {
441 if (!iters({Par(), Par(), Red()}))
442 return failure();
443 // Set up the parallel/reduction structure in the right form.
444 AffineExpr m, n, k;
445 bindDims(rewriter.getContext(), m, n, k);
446
447 // Classical row-major matmul: Just permute the lhs.
448 if (layout({{m, k}, {k, n}, {m, n}})) {
449 if (auto reductionSize = getReductionSize(lhsType, 1)) {
450 // Note: `t` creates new IR. It must be nested within this `if` check
451 // so that no IR is created when then pattern returns "failure".
452 Value tLhs = t(lhs);
453 Value tMask = t(mask, {2, 0, 1});
454 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
455 }
456 }
457 // TODO: may be better to fail and use some vector<k> -> scalar reduction.
458 if (layout({{m, k}, {n, k}, {m, n}})) {
459 if (auto reductionSize = getReductionSize(lhsType, 1)) {
460 Value tLhs = t(lhs);
461 Value tRhs = t(rhs);
462 Value tMask = t(mask, {2, 0, 1});
463 return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
464 }
465 }
466 // No need to permute anything.
467 if (layout({{k, m}, {k, n}, {m, n}})) {
468 if (auto reductionSize = getReductionSize(lhsType, 0)) {
469 Value tMask = t(mask, {2, 0, 1});
470 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
471 }
472 }
473 // Just permute the rhs.
474 if (layout({{k, m}, {n, k}, {m, n}})) {
475 if (auto reductionSize = getReductionSize(lhsType, 0)) {
476 Value tRhs = t(rhs);
477 Value tMask = t(mask, {2, 0, 1});
478 return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
479 }
480 }
481 // Transposed output: swap RHS and LHS.
482 // Classical row-major matmul: permute the lhs.
483 if (layout({{m, k}, {k, n}, {n, m}})) {
484 if (auto reductionSize = getReductionSize(lhsType, 1)) {
485 Value tLhs = t(lhs);
486 Value tMask = t(mask, {2, 0, 1});
487 return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
488 }
489 }
490 // TODO: may be better to fail and use some vector<k> -> scalar reduction.
491 if (layout({{m, k}, {n, k}, {n, m}})) {
492 if (auto reductionSize = getReductionSize(lhsType, 1)) {
493 Value tRhs = t(rhs);
494 Value tLhs = t(lhs);
495 Value tMask = t(mask, {2, 0, 1});
496 return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
497 }
498 }
499 if (layout({{k, m}, {k, n}, {n, m}})) {
500 if (auto reductionSize = getReductionSize(lhsType, 0)) {
501 Value tMask = t(mask, {2, 0, 1});
502 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
503 }
504 }
505 if (layout({{k, m}, {n, k}, {n, m}})) {
506 if (auto reductionSize = getReductionSize(lhsType, 0)) {
507 Value tRhs = t(rhs);
508 Value tMask = t(mask, {2, 0, 1});
509 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
510 }
511 }
512 return failure();
513 }
514
515 //
516 // One outer parallel, one inner reduction (matvec flavor).
517 // Mask needs to be transposed everywhere to turn the reduction dimension
518 // outermost as required by outerproduct.
519 //
520 FailureOr<Value> matvec() {
521 if (!iters({Par(), Red()}))
522 return failure();
523 AffineExpr m, k;
524 bindDims(rewriter.getContext(), m, k);
525
526 // Case mat-vec: transpose.
527 if (layout({{m, k}, {k}, {m}})) {
528 if (auto reductionSize = getReductionSize(lhsType, 1)) {
529 Value tLhs = t(lhs);
530 Value tMask = t(mask);
531 return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
532 }
533 }
534 // Case mat-trans-vec: ready to go.
535 if (layout({{k, m}, {k}, {m}})) {
536 if (auto reductionSize = getReductionSize(lhsType, 0)) {
537 Value tMask = t(mask);
538 return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
539 }
540 }
541 // Case vec-mat: swap and transpose.
542 if (layout({{k}, {m, k}, {m}})) {
543 if (auto reductionSize = getReductionSize(lhsType, 0)) {
544 Value tRhs = t(rhs);
545 Value tMask = t(mask);
546 return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
547 }
548 }
549 // Case vec-mat-trans: swap and ready to go.
550 if (layout({{k}, {k, m}, {m}})) {
551 if (auto reductionSize = getReductionSize(lhsType, 0)) {
552 Value tMask = t(mask);
553 return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
554 }
555 }
556 return failure();
557 }
558
559 //
560 // One outer reduction, one inner parallel (tmatvec flavor).
561 // Mask already has the shape of the outer product.
562 //
563 FailureOr<Value> tmatvec() {
564 if (!iters({Red(), Par()}))
565 return failure();
566 AffineExpr k, m;
567 bindDims(rewriter.getContext(), k, m);
568
569 // Case mat-vec: transpose.
570 if (layout({{m, k}, {k}, {m}}))
571 if (auto reductionSize = getReductionSize(lhsType, 1))
572 return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
573 // Case mat-trans-vec: ready to go.
574 if (layout({{k, m}, {k}, {m}}))
575 if (auto reductionSize = getReductionSize(lhsType, 0))
576 return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
577 // Case vec-mat: swap and transpose.
578 if (layout({{k}, {m, k}, {m}}))
579 if (auto reductionSize = getReductionSize(lhsType, 0))
580 return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
581 // Case vec-mat-trans: swap and ready to go.
582 if (layout({{k}, {k, m}, {m}}))
583 if (auto reductionSize = getReductionSize(lhsType, 0))
584 return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
585 return failure();
586 }
587
588private:
589 vector::CombiningKind kind;
590 Value lhs, rhs, res, mask;
591 VectorType lhsType;
592};
593
594/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
595/// semantics to a reduction_size-unrolled sequence:
596/// ```
597/// %at = vector.transpose %a, [1, 0]
598/// %bRow0 = vector.extract %b[0]
599/// %atRow0 = vector.extract %at[0]
600/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
601/// ...
602/// %bRowK = vector.extract %b[K]
603/// %atRowK = vector.extract %at[K]
604/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
605/// ```
606///
607/// This only kicks in when vectorContractLowering is set to OuterProduct but
608/// otherwise supports any layout permutation of the matrix-multiply.
609FailureOr<Value>
610ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
611 vector::ContractionOp op, MaskingOpInterface maskOp,
612 PatternRewriter &rewriter) const {
613 if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
614 return failure();
615
616 if (failed(filter(op)))
617 return failure();
618
619 UnrolledOuterProductGenerator e(rewriter, op);
620 FailureOr<Value> matmatRes = e.matmat();
621 if (succeeded(matmatRes)) {
622 return matmatRes;
623 }
624 FailureOr<Value> matvecRes = e.matvec();
625 if (succeeded(matvecRes)) {
626 return matvecRes;
627 }
628
629 FailureOr<Value> tmatvecRes = e.tmatvec();
630 return tmatvecRes;
631}
632
/// Lowers a matmat-style (parallel, parallel, reduction) or matvec-style
/// (parallel, reduction) contraction to explicit dot products.
///
/// Returns failure, before creating any IR, when:
///  - the op is masked;
///  - `filter` rejects it;
///  - the lowering option is not Dot; or
///  - the iterator types / indexing maps match none of the layouts below.
///
/// Otherwise `lhs` and `rhs` are transposed and/or swapped so that each is
/// indexed (result-dim, reduction-dim). Every destination element [r] or
/// [r, c] is then the add-reduction of the elementwise product of row `r` of
/// `lhs` and row `c` of `rhs`; for rank-1 results the whole of `rhs` is used.
/// Finally, the accumulator (if present) is added with `createAdd`.
///
/// NOTE(review): `op.getKind()` is never consulted. The product is always
/// reduced and accumulated with addition — confirm non-ADD combining kinds
/// cannot reach this pattern.
633FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
634 vector::ContractionOp op, MaskingOpInterface maskOp,
635 PatternRewriter &rewriter) const {
636 // TODO: Support vector.mask.
637 if (maskOp)
638 return failure();
639
640 if (failed(filter(op)))
641 return failure();
642
643 if (vectorContractLowering != vector::VectorContractLowering::Dot)
644 return failure();
645
646 auto iteratorTypes = op.getIteratorTypes().getValue();
647 static constexpr std::array<int64_t, 2> perm = {1, 0};
648 Location loc = op.getLoc();
649 Value lhs = op.getLhs(), rhs = op.getRhs();
650
651 using MapList = ArrayRef<ArrayRef<AffineExpr>>;
652 auto infer = [&](MapList m) {
653 return AffineMap::inferFromExprList(m, op.getContext());
654 };
655 AffineExpr m, n, k;
656 bindDims(rewriter.getContext(), m, n, k);
657 SmallVector<AffineMap> maps = op.getIndexingMapsArray();
658 //
659 // In the following we wish to make the reduction dimension innermost so we
660 // can load vectors and just fmul + reduce into a scalar.
661 //
// NOTE(review): iteratorTypes[2] is read whenever the first two iterators are
// parallel. This presumes such contractions always have a third iterator —
// confirm the op verifier guarantees it.
662 if (isParallelIterator(iteratorTypes[0]) &&
663 isParallelIterator(iteratorTypes[1]) &&
664 isReductionIterator(iteratorTypes[2])) {
665 //
666 // Two outer parallel, one inner reduction (matmat flavor).
667 //
668 if (maps == infer({{m, k}, {k, n}, {m, n}})) {
669 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
670 } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
671 // No need to permute anything.
672 } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
673 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
674 rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
675 } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
676 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
677 } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
678 // Transposed output (n, m): rows come from rhs^T (n x k) and columns
// from the original lhs (m x k).
679 Value tmp = lhs;
680 lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
681 rhs = tmp;
682 } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
683 std::swap(lhs, rhs);
684 } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
685 Value tmp = lhs;
686 lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
687 rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
688 } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
689 Value tmp = rhs;
690 rhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
691 lhs = tmp;
692 } else {
693 return failure();
694 }
695 } else if (isParallelIterator(iteratorTypes[0]) &&
696 isReductionIterator(iteratorTypes[1])) {
697 //
698 // One outer parallel, one inner reduction (matvec flavor)
699 //
700 if (maps == infer({{m, n}, {n}, {m}})) {
701 // No need to permute anything.
702 } else if (maps == infer({{n, m}, {n}, {m}})) {
703 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
704 } else if (maps == infer({{n}, {m, n}, {m}})) {
705 std::swap(lhs, rhs);
706 } else if (maps == infer({{n}, {n, m}, {m}})) {
707 std::swap(lhs, rhs);
708 lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
709 } else {
710 return failure();
711 }
712 } else {
713 return failure();
714 }
715
716 VectorType dstType = cast<VectorType>(op.getResultType());
717 assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
718 "Expected dst type of rank 1 or 2");
719
720 unsigned rank = dstType.getRank();
721 unsigned dstRows = dstType.getShape()[0];
722 unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];
723
724 // ExtractOp does not allow dynamic indexing, we must unroll explicitly.
725 Value res = arith::ConstantOp::create(rewriter, loc, dstType,
726 rewriter.getZeroAttr(dstType));
// NOTE(review): only IntegerType selects the integer path, so index element
// types take the float path here. The elementwise lowering below uses
// isIntOrIndex() — confirm which is intended.
727 bool isInt = isa<IntegerType>(dstType.getElementType());
728 arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
729 llvm::SmallVector<Value> extractedCols;
730 extractedCols.reserve(dstColumns);
731 for (unsigned r = 0; r < dstRows; ++r) {
732 Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(), lhs, r);
733 for (unsigned c = 0; c < dstColumns; ++c) {
734 // Extract each respective row and column of the LHS and RHS once to
735 // avoid having duplicate SSA values pointing to the same rows/columns.
736 if (r == 0) {
737 Value colRhs =
738 rank == 1
739 ? rhs
740 : vector::ExtractOp::create(rewriter, op.getLoc(), rhs, c);
741 extractedCols.push_back(colRhs);
742 }
743 Value extractedColRhs = extractedCols[c];
// Dot product: elementwise multiply, then add-reduce to a scalar.
744 Value product =
745 createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter, fmf);
746 Value sum = vector::ReductionOp::create(rewriter, op.getLoc(),
747 vector::CombiningKind::ADD,
748 product, op.getFastmath());
749
// NOTE(review): `pos` is not defined in this listing. It presumably holds the
// destination index [r] (rank 1) or [r, c] (rank 2) — confirm.
752 res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
753 }
754 }
755 if (auto acc = op.getAcc())
756 res = createAdd(op.getLoc(), res, acc, isInt, rewriter, fmf);
757 return res;
758}
759
760/// Lower vector.contract with all size one reduction dimensions to
761/// elementwise ops when possible.
///
/// Applies only when all of the following hold:
///  - the lowering option is ParallelArith;
///  - the op is unmasked and accepted by `filter`;
///  - every reduction dimension of both operands has size 1.
///
/// The rewrite proceeds in four steps:
///  1. Each operand is broadcast, with new leading dims, to cover the
///     accumulator's parallel dims it lacks.
///  2. It is transposed so that its reduction dims come first, followed by the
///     parallel dims in accumulator order.
///  3. The size-1 reduction dims are stripped with a `vector.extract` at
///     offset 0.
///  4. The two operands are combined with `createContractArithOp` using the
///     op's combining kind.
763 : public MaskableOpRewritePattern<vector::ContractionOp> {
764 using MaskableOpRewritePattern::MaskableOpRewritePattern;
766 std::function<LogicalResult(vector::ContractionOp op)>;
767 static LogicalResult defaultFilter(vector::ContractionOp op) {
768 return success();
769 }
771 vector::VectorContractLowering vectorContractLowering,
772 MLIRContext *context, PatternBenefit benefit = 1,
773 const FilterConstraintType &constraint = defaultFilter)
774 : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
775 vectorContractLowering(vectorContractLowering), filter(defaultFilter) {}
// NOTE(review): `filter` is initialized from `defaultFilter`, so the
// `constraint` argument is silently ignored; `filter(constraint)` looks
// intended.
776
777 FailureOr<Value>
778 matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
779 MaskingOpInterface maskOp,
780 PatternRewriter &rewriter) const override {
781 // TODO: Support vector.mask.
782 if (maskOp)
783 return failure();
784
785 if (failed(filter(contractOp)))
786 return failure();
787
788 if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
789 return failure();
790
791 ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
792 ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
793 AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
794 AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
795 SmallVector<int64_t> lhsReductionDims =
796 getReductionIndex(lhsMap, contractOp.getIteratorTypes());
797 SmallVector<int64_t> rhsReductionDims =
798 getReductionIndex(rhsMap, contractOp.getIteratorTypes());
799 // All the reduction dimensions must be a size 1.
800 for (int64_t dim : lhsReductionDims) {
801 if (lhsShape[dim] != 1)
802 return failure();
803 }
804 for (int64_t dim : rhsReductionDims) {
805 if (rhsShape[dim] != 1)
806 return failure();
807 }
// Number of accumulator (parallel) dims each operand lacks; these are later
// prepended by broadcasting. This assumes every parallel dim of an operand
// also appears in the accumulator map.
808 AffineMap accMap = contractOp.getIndexingMapsArray()[2];
809 unsigned numParallelDims = accMap.getNumResults();
810 unsigned numLhsDimToBroadcast =
811 numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
812 unsigned numRhsDimToBroadcast =
813 numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
814 SmallVector<int64_t> lhsDims;
815 SmallVector<int64_t> lhsTranspose;
816 SmallVector<int64_t> rhsDims;
817 SmallVector<int64_t> rhsTranspose;
// Reduction dims go first in the permutation, shifted past the broadcast dims
// that will be prepended.
818 for (int64_t dim : lhsReductionDims)
819 lhsTranspose.push_back(numLhsDimToBroadcast + dim);
820 for (int64_t dim : rhsReductionDims)
821 rhsTranspose.push_back(numRhsDimToBroadcast + dim);
822 // Loop through the parallel dimensions to calculate the dimensions to
823 // broadcast and to permute in order to extract only parallel dimensions.
824 for (unsigned i = 0; i < numParallelDims; i++) {
825 std::optional<unsigned> lhsDim =
826 getDimPosition(lhsMap, accMap.getDimPosition(i));
827 if (lhsDim) {
828 lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
829 } else {
830 // If the parallel dimension doesn't exist we will have to broadcast it.
831 lhsDims.push_back(
832 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
833 lhsTranspose.push_back(lhsDims.size() - 1);
834 }
835 std::optional<unsigned> rhsDim =
836 getDimPosition(rhsMap, accMap.getDimPosition(i));
837 if (rhsDim) {
838 rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
839 } else {
840 // If the parallel dimension doesn't exist we will have to broadcast it.
841 rhsDims.push_back(
842 cast<VectorType>(contractOp.getResultType()).getDimSize(i));
843 rhsTranspose.push_back(rhsDims.size() - 1);
844 }
845 }
846 Value newLhs = contractOp.getLhs();
847 Value newRhs = contractOp.getRhs();
848 Location loc = contractOp.getLoc();
// Broadcast: the missing parallel dims become new leading dims, followed by
// the operand's original shape.
849 if (!lhsDims.empty()) {
850 lhsDims.append(lhsShape.begin(), lhsShape.end());
851 auto expandedType =
852 VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
853 newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
854 }
855 if (!rhsDims.empty()) {
856 rhsDims.append(rhsShape.begin(), rhsShape.end());
857 auto expandedType =
858 VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
859 newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
860 }
861 bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
862 newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
863 newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
// The (size-1) reduction dims are now leading; extracting at offset 0 along
// each of them leaves only the parallel dims.
864 SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
865 SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
866 newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
867 newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
868 std::optional<Value> result =
869 createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
870 contractOp.getKind(), rewriter, isInt,
871 /*mask=*/Value(), contractOp.getFastmathAttr());
// NOTE(review): broadcast/transpose/extract ops have already been created at
// this point, yet createContractArithOp can still reject the combining kind.
// Returning failure() after modifying IR violates the rewrite-pattern
// contract; consider validating the kind before creating any op.
872 if (result)
873 return *result;
874
875 return failure();
876 }
877
878private:
879 /// Options to control the vector patterns.
880 vector::VectorContractLowering vectorContractLowering;
882};
883
884/// Progressive lowering of ContractionOp.
885/// One:
886/// %x = vector.contract with at least one free/batch dimension
887/// is replaced by:
888/// %a = vector.contract with one less free/batch dimension
889/// %b = vector.contract with one less free/batch dimension
890/// ..
891/// %x = combine %a %b ..
892/// until a pure contraction is reached (no free/batch dimensions),
893/// which is replaced by a dot-product.
894///
895/// This only kicks in when either vectorContractLoweringOption is set
896/// to DOT or when other contraction patterns fail.
897//
898// TODO: break down into transpose/reshape/cast ops
899// when they become available to avoid code dup
900// TODO: investigate lowering order impact on performance
901FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
902 vector::ContractionOp op, MaskingOpInterface maskOp,
903 PatternRewriter &rewriter) const {
904 if (failed(filter(op)))
905 return failure();
906
907 // TODO: support mixed mode contract lowering.
908 if (op.getLhsType().getElementType() !=
909 getElementTypeOrSelf(op.getAccType()) ||
910 op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
911 return failure();
912
913 // TODO: the code below assumes the default contraction, make sure it supports
914 // other kinds before enabling this lowering.
915 if (op.getKind() != vector::CombiningKind::ADD) {
916 return rewriter.notifyMatchFailure(
917 op, "contractions other than 'add' not supported");
918 }
919
920 // TODO: implement benefits, cost models.
921 MLIRContext *ctx = op.getContext();
922
923 ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
924 FailureOr<Value> newVal1 =
925 pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
926 if (!failed(newVal1))
927 return newVal1;
928
929 ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
930 FailureOr<Value> newVal2 =
931 pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
932 if (!failed(newVal2))
933 return newVal2;
934
935 ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
936 FailureOr<Value> newVal4 =
937 pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
938 if (!failed(newVal4))
939 return newVal4;
940
941 // Vector mask setup.
942
943 Value mask;
944 if (maskOp)
945 mask = maskOp.getMask();
946 // Find first batch dimension in LHS/RHS, and lower when found.
947 std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
948 if (!batchDimMap.empty()) {
949 int64_t lhsIndex = batchDimMap[0].first;
950 int64_t rhsIndex = batchDimMap[0].second;
951 auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
952 if (failed(newOp))
953 return failure();
954 return newOp;
955 }
956
957 // Collect contracting dimensions.
958 std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
959 op.getContractingDimMap();
960 DenseSet<int64_t> lhsContractingDimSet;
961 DenseSet<int64_t> rhsContractingDimSet;
962 for (auto &dimPair : contractingDimMap) {
963 lhsContractingDimSet.insert(dimPair.first);
964 rhsContractingDimSet.insert(dimPair.second);
965 }
966
967 // Find first free dimension in LHS, and lower when found.
968 VectorType lhsType = op.getLhsType();
969 for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
970 if (lhsContractingDimSet.count(lhsIndex) == 0) {
971 auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
972 if (failed(newOp))
973 return failure();
974 return newOp;
975 }
976 }
977
978 // Find first free dimension in RHS, and lower when found.
979 VectorType rhsType = op.getRhsType();
980 for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
981 if (rhsContractingDimSet.count(rhsIndex) == 0) {
982 auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
983 if (failed(newOp))
984 return failure();
985 return newOp;
986 }
987 }
988
989 // Lower the first remaining reduction dimension.
990 if (!contractingDimMap.empty()) {
991 auto newOp = lowerReduction(rewriter, op, mask);
992 if (failed(newOp))
993 return failure();
994 return newOp;
995 }
996
997 return failure();
998}
999
1000// Lower one parallel dimension.
1001// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
1002// TODO: consider reusing existing contract unrolling
1003FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
1004 vector::ContractionOp op,
1005 int64_t lhsIndex,
1006 int64_t rhsIndex,
1007 Value mask) const {
1008 VectorType lhsType = op.getLhsType();
1009 VectorType rhsType = op.getRhsType();
1010 VectorType resType = cast<VectorType>(op.getResultType());
1011 // Find the iterator type index and result index.
1012 SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1013 int64_t iterIndex = -1;
1014 int64_t dimSize = -1;
1015 if (lhsIndex >= 0) {
1016 iterIndex = iMap[0].getDimPosition(lhsIndex);
1017 if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
1018 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1019 diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
1020 << " to map to the same dimension";
1021 });
1022 if (lhsType.getScalableDims()[lhsIndex])
1023 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1024 diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
1025 << ") is not supported yet";
1026 });
1027 dimSize = lhsType.getDimSize(lhsIndex);
1028 } else if (rhsIndex >= 0) {
1029 iterIndex = iMap[1].getDimPosition(rhsIndex);
1030 if (rhsType.getScalableDims()[rhsIndex])
1031 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1032 diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
1033 << ") is not supported yet";
1034 });
1035 dimSize = rhsType.getDimSize(rhsIndex);
1036 }
1037 if (iterIndex < 0)
1038 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1039 diag << "expected either lhsIndex=" << lhsIndex
1040 << " or rhsIndex=" << rhsIndex << " to be nonnegative";
1041 });
1042 // value_or(-1) means that we tolerate a dimension not appearing
1043 // in the result map. That can't happen for actual parallel iterators, but
1044 // the caller ContractionOpLowering::matchAndRewrite is currently calling
1045 // lowerParallel also for the case of unit-size reduction dims appearing only
1046 // on one of LHS or RHS, not both. At the moment, such cases are created by
1047 // CastAwayContractionLeadingOneDim, so we need to either support that or
1048 // modify that pattern.
1049 int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
1050 if (resIndex == -1 && dimSize != 1)
1051 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1052 diag << "expected the dimension for iterIndex=" << iterIndex
1053 << " to either appear in the result map, or to be a unit dimension";
1054 });
1055
1056 // Construct new iterator types and affine map array attribute.
1057 std::array<AffineMap, 3> lowIndexingMaps = {
1058 adjustMap(iMap[0], iterIndex, rewriter),
1059 adjustMap(iMap[1], iterIndex, rewriter),
1060 adjustMap(iMap[2], iterIndex, rewriter)};
1061 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1062 auto lowIter =
1063 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1064 // Unroll into a series of lower dimensional vector.contract ops.
1065 Location loc = op.getLoc();
1066 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1067 rewriter.getZeroAttr(resType));
1068
1069 for (int64_t d = 0; d < dimSize; ++d) {
1070 auto lhs = reshapeLoad(loc, op.getLhs(), lhsIndex, d, rewriter);
1071 auto rhs = reshapeLoad(loc, op.getRhs(), rhsIndex, d, rewriter);
1072 auto acc = reshapeLoad(loc, op.getAcc(), resIndex, d, rewriter);
1073
1074 Value lowMask;
1075 if (mask)
1076 lowMask = reshapeLoad(loc, mask, iterIndex, d, rewriter);
1077
1078 Operation *lowContract =
1079 vector::ContractionOp::create(rewriter, loc, lhs, rhs, acc, lowAffine,
1080 lowIter, op.getKind(), op.getFastmath());
1081 lowContract = maskOperation(rewriter, lowContract, lowMask);
1082 result = reshapeStore(loc, lowContract->getResult(0), result, resIndex, d,
1083 rewriter);
1084 }
1085 return result;
1086}
1087
1088// Lower one reduction dimension.
1089FailureOr<Value> ContractionOpLowering::lowerReduction(
1090 PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
1091 auto loc = op.getLoc();
1092 VectorType lhsType = op.getLhsType();
1093 VectorType rhsType = op.getRhsType();
1094 Type resType = op.getResultType();
1095 if (isa<VectorType>(resType))
1096 return rewriter.notifyMatchFailure(op,
1097 "did not expect a VectorType result");
1098 bool isInt = isa<IntegerType>(resType);
1099 // Use iterator index 0.
1100 int64_t iterIndex = 0;
1101 SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
1102 std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
1103 std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
1104 if (!lookupLhs.has_value())
1105 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1106 diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension";
1107 });
1108 if (!lookupRhs.has_value())
1109 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1110 diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension";
1111 });
1112 int64_t lhsIndex = *lookupLhs;
1113 int64_t rhsIndex = *lookupRhs;
1114 int64_t dimSize = lhsType.getDimSize(lhsIndex);
1115 if (dimSize != rhsType.getDimSize(rhsIndex))
1116 return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1117 diag << "expect LHS dimension " << lhsIndex
1118 << " to have the same size as RHS dimension " << rhsIndex;
1119 });
1120 // Base case.
1121 if (lhsType.getRank() == 1) {
1122 if (rhsType.getRank() != 1)
1123 return rewriter.notifyMatchFailure(
1124 op, "When LHS has rank 1, expected also RHS to have rank 1");
1125 arith::FastMathFlagsAttr fmf = op.getFastmathAttr();
1126 Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter, fmf);
1127 auto kind = vector::CombiningKind::ADD;
1128
1129 Value acc = op.getAcc();
1130 Operation *reductionOp =
1131 acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc,
1132 op.getFastmath())
1133 : vector::ReductionOp::create(rewriter, loc, kind, m,
1134 op.getFastmath());
1135 return maskOperation(rewriter, reductionOp, mask)->getResult(0);
1136 }
1137 // Construct new iterator types and affine map array attribute.
1138 std::array<AffineMap, 3> lowIndexingMaps = {
1139 adjustMap(iMap[0], iterIndex, rewriter),
1140 adjustMap(iMap[1], iterIndex, rewriter),
1141 adjustMap(iMap[2], iterIndex, rewriter)};
1142 auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
1143 auto lowIter =
1144 rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
1145 // Unroll into a series of lower dimensional vector.contract ops.
1146 // By feeding the initial accumulator into the first contraction,
1147 // and the result of each contraction into the next, eventually
1148 // the sum of all reductions is computed.
1149 Value result = op.getAcc();
1150 for (int64_t d = 0; d < dimSize; ++d) {
1151 auto lhs = reshapeLoad(loc, op.getLhs(), lhsIndex, d, rewriter);
1152 auto rhs = reshapeLoad(loc, op.getRhs(), rhsIndex, d, rewriter);
1153 Value newMask;
1154 if (mask)
1155 newMask = reshapeLoad(loc, mask, iterIndex, d, rewriter);
1156
1157 Operation *newContract = vector::ContractionOp::create(
1158 rewriter, loc, lhs, rhs, result, lowAffine, lowIter, op.getKind(),
1159 op.getFastmath());
1160 result = maskOperation(rewriter, newContract, newMask)->getResult(0);
1161 }
1162 return result;
1163}
1164
1165/// Progressive lowering of OuterProductOp.
1166/// One:
1167/// %x = vector.outerproduct %lhs, %rhs, %acc
1168/// is replaced by:
1169/// %z = zero-result
1170/// %0 = vector.extract %lhs[0]
1171/// %1 = vector.broadcast %0
1172/// %2 = vector.extract %acc[0]
1173/// %3 = vector.fma %1, %rhs, %2
1174/// %4 = vector.insert %3, %z[0]
1175/// ..
1176/// %x = vector.insert %.., %..[N-1]
1177///
1178class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
1179public:
1180 using Base::Base;
1181
1182 LogicalResult matchAndRewrite(vector::OuterProductOp op,
1183 PatternRewriter &rewriter) const override {
1184 VectorType resType = op.getResultVectorType();
1185 if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
1186 return failure();
1187
1188 auto loc = op.getLoc();
1189
1190 VectorType lhsType = op.getOperandVectorTypeLHS();
1191 VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
1192 Type eltType = resType.getElementType();
1193 bool isInt = isa<IntegerType, IndexType>(eltType);
1194 Value acc = op.getAcc();
1195 vector::CombiningKind kind = op.getKind();
1196
1197 // Vector mask setup.
1198 OpBuilder::InsertionGuard guard(rewriter);
1199 auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
1200 Operation *rootOp;
1201 Value mask;
1202 if (maskableOp.isMasked()) {
1203 rewriter.setInsertionPoint(maskableOp.getMaskingOp());
1204 rootOp = maskableOp.getMaskingOp();
1205 mask = maskableOp.getMaskingOp().getMask();
1206 } else {
1207 rootOp = op;
1208 }
1209
1210 if (!rhsType) {
1211 // Special case: AXPY operation.
1212 Value b =
1213 vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
1214 std::optional<Value> mult = createContractArithOp(
1215 loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
1216 if (!mult.has_value())
1217 return failure();
1218 rewriter.replaceOp(rootOp, *mult);
1219 return success();
1220 }
1221
1222 Value result = arith::ConstantOp::create(rewriter, loc, resType,
1223 rewriter.getZeroAttr(resType));
1224 for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
1225 Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
1226 Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
1227 Value r = nullptr;
1228 if (acc)
1229 r = vector::ExtractOp::create(rewriter, loc, acc, d);
1230 Value extrMask;
1231 if (mask)
1232 extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);
1233
1234 std::optional<Value> m = createContractArithOp(
1235 loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
1236 if (!m.has_value())
1237 return failure();
1238 result = vector::InsertOp::create(rewriter, loc, *m, result, d);
1239 }
1240
1241 rewriter.replaceOp(rootOp, result);
1242 return success();
1243 }
1244};
1245
1246} // namespace
1247
1249 RewritePatternSet &patterns,
1250 VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
1251 bool disableOuterProductLowering) {
1252 if (!disableOuterProductLowering)
1253 patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
1254 patterns.add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
1255 vectorContractLoweringOption, patterns.getContext(), benefit);
1256}
1257
return success()
static int64_t product(ArrayRef< int64_t > vals)
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
auto load
static std::optional< int64_t > getResultIndex(AffineMap map, int64_t index)
static SmallVector< int64_t > getReductionIndex(AffineMap map, ArrayAttr iteratorTypes)
Return the positions of the reductions in the given map.
static std::optional< unsigned > getDimPosition(AffineMap map, unsigned dim)
Look for a given dimension in an affine map and return its position.
static Value reshapeStore(Location loc, Value val, Value result, int64_t index, int64_t pos, PatternRewriter &rewriter)
Inserts val into result at position pos along dimension index.
static SmallVector< Attribute > adjustIter(ArrayAttr iteratorTypes, int64_t index)
static Value createAdd(Location loc, Value x, Value y, bool isInt, PatternRewriter &rewriter, arith::FastMathFlagsAttr fmf={})
Creates an AddIOp if isInt is true otherwise create an arith::AddFOp using operands x and y.
static AffineMap adjustMap(AffineMap map, int64_t index, PatternRewriter &rewriter)
static Value reshapeLoad(Location loc, Value val, int64_t index, int64_t pos, PatternRewriter &rewriter)
Returns val with the dimension at position index dropped by indexing that dimension with pos.
static std::optional< Value > createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, bool isInt, Value mask=Value(), arith::FastMathFlagsAttr fmf={})
Helper to create arithmetic operation associated with a kind of contraction.
static std::string diag(const llvm::Value &value)
#define mul(a, b)
Progressive lowering of OuterProductOp.
LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override
Base type for affine expression.
Definition AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
unsigned getNumResults() const
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr > > exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition Builders.cpp:323
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:432
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Helper StructuredGenerator class to manipulate and rewrite ops with StructuredOpInterface.
bool iters(ArrayRef< IteratorType > its)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
This is a builder type that keeps local references to arguments.
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
void promote(RewriterBase &rewriter, scf::ForallOp forallOp)
Promotes the loop body of a scf::ForallOp to its containing block.
Definition SCF.cpp:748
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition VectorOps.h:156
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition VectorOps.h:151
void populateVectorOuterProductLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
void populateVectorContractLoweringPatterns(RewritePatternSet &patterns, VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit=1, bool disableOuterProductLowering=false)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition AffineExpr.h:311
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Lower vector.contract with all size one reduction dimensions to elementwise ops when possible.
FailureOr< Value > matchAndRewriteMaskableOp(vector::ContractionOp contractOp, MaskingOpInterface maskOp, PatternRewriter &rewriter) const override
static LogicalResult defaultFilter(vector::ContractionOp op)
ContractOpToElementwise(vector::VectorContractLowering vectorContractLowering, MLIRContext *context, PatternBenefit benefit=1, const FilterConstraintType &constraint=defaultFilter)
std::function< LogicalResult(vector::ContractionOp op)> FilterConstraintType
OpRewritePattern Base
Type alias to allow derived classes to inherit constructors with using Base::Base;.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
A pattern for ops that implement MaskableOpInterface and that might be masked (i.e.
virtual FailureOr< Value > matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp, PatternRewriter &rewriter) const =0