//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.contract' operation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "vector-contract-lowering"

using namespace mlir;
using namespace mlir::vector;

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//
// Helper to find the position at which a given dimension appears in the
// results of an affine map; returns std::nullopt if it is absent.
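// For example, with map (d0, d1, d2) -> (d2, d0), getResultIndex(map, 2)
// returns 0 and getResultIndex(map, 1) returns std::nullopt.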
static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      return i;
  }
  return std::nullopt;
}

// Helper to construct iterator types with one index removed.
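// For example, removing index 1 from ["parallel", "reduction", "parallel"]
// yields ["parallel", "parallel"].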
static SmallVector<Attribute> adjustIter(ArrayAttr iteratorTypes,
                                         int64_t index) {
  SmallVector<Attribute> results;
  for (const auto &it : llvm::enumerate(iteratorTypes)) {
    int64_t idx = it.index();
    if (idx == index)
      continue;
    results.push_back(it.value());
  }
  return results;
}

// Helper to construct an affine map with one index removed.
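// For example, adjustMap((d0, d1, d2) -> (d2, d0), /*index=*/1, ...) yields
// (d0, d1) -> (d1, d0): results equal to d1 are dropped and d2 is renumbered
// to d1.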
static AffineMap adjustMap(AffineMap map, int64_t index,
                           PatternRewriter &rewriter) {
  auto *ctx = rewriter.getContext();
  SmallVector<AffineExpr> results;
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    int64_t idx = map.getDimPosition(i);
    if (idx == index)
      continue;
    // Re-insert remaining indices, but renumbered when occurring after the
    // removed index.
    auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx);
    results.push_back(targetExpr);
  }
  return AffineMap::get(map.getNumDims() - 1, 0, results, ctx);
}

// Helper method to possibly drop a dimension in a load.
// TODO
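// For example, given %v : vector<2x3x4xf32>, index = 1 and pos = p, this
// produces a vector<2x4xf32> whose d-th row is extracted from %v[d][p].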
static Value reshapeLoad(Location loc, Value val, VectorType type,
                         int64_t index, int64_t pos,
                         PatternRewriter &rewriter) {
  if (index == -1)
    return val;

  // At extraction dimension?
  if (index == 0)
    return vector::ExtractOp::create(rewriter, loc, val, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  VectorType resType = VectorType::Builder(type).dropDim(index);
  Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                           rewriter.getZeroAttr(resType));
  for (int64_t d = 0, e = resType.getDimSize(0); d < e; d++) {
    Value ext = vector::ExtractOp::create(rewriter, loc, val, d);
    Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter);
    result = vector::InsertOp::create(rewriter, loc, load, result, d);
  }
  return result;
}

// Helper method to possibly drop a dimension in a store.
// TODO
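// For example, given %result : vector<2x3x4xf32>, %val : vector<2x4xf32>,
// index = 1 and pos = p, this writes row %val[d] into %result[d][p] for each
// d, leaving all other positions of %result unchanged.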
static Value reshapeStore(Location loc, Value val, Value result,
                          VectorType type, int64_t index, int64_t pos,
                          PatternRewriter &rewriter) {
  // Unmodified?
  if (index == -1)
    return val;
  // At insertion dimension?
  if (index == 0)
    return vector::InsertOp::create(rewriter, loc, val, result, pos);

  // Unroll leading dimensions.
  VectorType vType = VectorType::Builder(type).dropDim(0);
  for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) {
    Value ext = vector::ExtractOp::create(rewriter, loc, result, d);
    Value ins = vector::ExtractOp::create(rewriter, loc, val, d);
    Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter);
    result = vector::InsertOp::create(rewriter, loc, sto, result, d);
  }
  return result;
}

/// Helper to create the arithmetic operation associated with a kind of
/// contraction.
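/// For example, for kind == CombiningKind::ADD on float vectors with a
/// non-null vector accumulator, the multiply and accumulate fold into a
/// single vector.fma; for integer operands this emits arith.muli followed by
/// the matching arith reduction.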
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
                      vector::CombiningKind kind, PatternRewriter &rewriter,
                      bool isInt, Value mask = Value()) {
  using vector::CombiningKind;
  Value mul;

  if (isInt) {
    if (kind == CombiningKind::MINNUMF || kind == CombiningKind::MAXNUMF ||
        kind == CombiningKind::MINIMUMF || kind == CombiningKind::MAXIMUMF)
      // Only valid for floating point types.
      return std::nullopt;
    mul = arith::MulIOp::create(rewriter, loc, x, y);
  } else {
    // Float case.
    if (kind == CombiningKind::AND || kind == CombiningKind::MINUI ||
        kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI ||
        kind == CombiningKind::MAXSI || kind == CombiningKind::OR ||
        kind == CombiningKind::XOR)
      // Only valid for integer types.
      return std::nullopt;
    // Special case for fused multiply-add.
    if (acc && isa<VectorType>(acc.getType()) && kind == CombiningKind::ADD) {
      Value fma = vector::FMAOp::create(rewriter, loc, x, y, acc);
      if (mask)
        // The fma op doesn't need explicit masking. However, fma ops used in
        // reductions must preserve previous 'acc' values for masked-out lanes.
        fma = selectPassthru(rewriter, mask, fma, acc);
      return fma;
    }
    mul = arith::MulFOp::create(rewriter, loc, x, y);
  }

  if (!acc)
    return std::optional<Value>(mul);

  return makeArithReduction(rewriter, loc, kind, mul, acc,
                            /*fastmath=*/nullptr, mask);
}

/// Return the positions of the reductions in the given map.
static SmallVector<int64_t> getReductionIndex(AffineMap map,
                                              ArrayAttr iteratorTypes) {
  SmallVector<int64_t> dimsIdx;
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (isReductionIterator(iteratorTypes[map.getDimPosition(i)]))
      dimsIdx.push_back(i);
  }
  return dimsIdx;
}

/// Look for a given dimension in an affine map and return its position.
/// Return std::nullopt if the dimension is not in the map results.
static std::optional<unsigned> getDimPosition(AffineMap map, unsigned dim) {
  for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
    if (map.getDimPosition(i) == dim)
      return i;
  }
  return std::nullopt;
}

/// Creates an arith::AddIOp if `isInt` is true, otherwise an arith::AddFOp,
/// using operands `x` and `y`.
static Value createAdd(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return arith::AddIOp::create(rewriter, loc, x, y);
  return arith::AddFOp::create(rewriter, loc, x, y);
}

/// Creates an arith::MulIOp if `isInt` is true, otherwise an arith::MulFOp,
/// using operands `x` and `y`.
static Value createMul(Location loc, Value x, Value y, bool isInt,
                       PatternRewriter &rewriter) {
  if (isInt)
    return arith::MulIOp::create(rewriter, loc, x, y);
  return arith::MulFOp::create(rewriter, loc, x, y);
}

namespace {

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major
/// matmul semantics to a reduction_size-unrolled sequence:
/// ```
///   %at = vector.transpose %a, [1, 0]
///   %bRow0 = vector.extract %b[0]
///   %atRow0 = vector.extract %at[0]
///   %c0 = vector.outerproduct %atRow0, %bRow0, %c
///   ...
///   %bRowK = vector.extract %b[K]
///   %atRowK = vector.extract %at[K]
///   %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when vectorContractLowering is set to OuterProduct and
/// the vector.contract op is a row-major matrix multiply.
class ContractionOpToOuterProductOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToOuterProductOpLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major
/// matmul semantics to an output-size-unrolled sequence:
/// ```
///   %out = arith.constant ... : vector<MxNxelt_type>
///   %bt = vector.transpose %b, [1, 0]
///   %aRow0 = vector.extract %a[0]
///   %btRow0 = vector.extract %bt[0]
///   %c00 = vector.reduce %aRow0, %btRow0
///   %out00 = vector.insert %c00, %out[0, 0]
///   ...
///   %aRowLast = vector.extract %a[M-1]
///   %btRowLast = vector.extract %bt[N-1]
///   %cLastLast = vector.reduce %aRowLast, %btRowLast
///   %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
/// ```
///
/// This only kicks in when vectorContractLowering is set to Dot and
/// the vector.contract op is a row-major matmul or matvec.
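///
/// For example (a sketch), the pattern applies to a contraction such as:
/// ```
///   %res = vector.contract {
///       indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
///                        affine_map<(m, n, k) -> (k, n)>,
///                        affine_map<(m, n, k) -> (m, n)>],
///       iterator_types = ["parallel", "parallel", "reduction"],
///       kind = #vector.kind<add>}
///     %a, %b, %c : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>
/// ```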
class ContractionOpToDotLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;

  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpToDotLowering(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
///
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either vectorContractLoweringOption is set
/// to Dot or when other contraction patterns fail.
class ContractionOpLowering
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;

  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }

  ContractionOpLowering(
      vector::VectorContractLowering vectorContractLoweringOption,
      MLIRContext *context, PatternBenefit benefit = 1,
      FilterConstraintType constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLoweringOption(vectorContractLoweringOption),
        filter(std::move(constraint)) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp op, MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override;

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLoweringOption;
  FilterConstraintType filter;
  // Lower one parallel dimension.
  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
                                 vector::ContractionOp op, int64_t lhsIndex,
                                 int64_t rhsIndex, Value mask) const;
  // Lower one reduction dimension.
  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
                                  vector::ContractionOp op, Value mask) const;
};

/// Generate a vector implementation for matmat, matvec and tmatvec.
/// This unrolls outer-products along the reduction dimension.
struct UnrolledOuterProductGenerator
    : public StructuredGenerator<vector::ContractionOp, vector::IteratorType> {
  UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
      : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
        kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
        res(op.getAcc()), lhsType(op.getLhsType()) {
    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
    if (maskableOp.isMasked())
      mask = maskableOp.getMaskingOp().getMask();
  }

  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
    if (!v)
      return v;
    return vector::TransposeOp::create(rewriter, loc, v, perm);
  }

  Value promote(Value v, Type dstElementType) {
    Type elementType = v.getType();
    auto vecType = dyn_cast<VectorType>(elementType);
    if (vecType)
      elementType = vecType.getElementType();
    if (elementType == dstElementType)
      return v;
    Type promotedType = dstElementType;
    if (vecType)
      promotedType = vecType.clone(promotedType);
    if (isa<FloatType>(dstElementType))
      return arith::ExtFOp::create(rewriter, loc, promotedType, v);
    return arith::ExtSIOp::create(rewriter, loc, promotedType, v);
  }

  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
                             VectorType lhsType, int reductionSize,
                             std::optional<Value> maybeMask = std::nullopt) {
    // Incremental support for masking.
    if (mask && !maybeMask.has_value())
      return failure();

    Type resElementType = cast<VectorType>(res.getType()).getElementType();
    for (int64_t k = 0; k < reductionSize; ++k) {
      Value extractA = vector::ExtractOp::create(rewriter, loc, lhs, k);
      Value extractB = vector::ExtractOp::create(rewriter, loc, rhs, k);
      extractA = promote(extractA, resElementType);
      extractB = promote(extractB, resElementType);
      Value extractMask;
      if (maybeMask.has_value() && maybeMask.value())
        extractMask =
            vector::ExtractOp::create(rewriter, loc, maybeMask.value(), k);

      Operation *outerProdOp = vector::OuterProductOp::create(
          rewriter, loc, res.getType(), extractA, extractB, res, kind);
      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
    }
    return res;
  }
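
  // For example, `outerProd` with lhs : vector<4x3xf32> (reduction dimension
  // leading), rhs : vector<4x2xf32> and reductionSize = 4 emits four chained
  // vector.outerproduct ops producing vector<3x2xf32>, each feeding its
  // result into the next as the accumulator.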

  /// Helper function for `matmat`, `matvec`, `tmatvec`. Returns the size of
  /// dimension `reductionDim`. If the dimension is a scalable dimension,
  /// returns "nullopt".
  std::optional<int64_t> getReductionSize(VectorType vecType,
                                          int64_t reductionDim) {
    // Cannot unroll scalable dimension.
    if (vecType.getScalableDims()[reductionDim])
      return std::nullopt;
    int64_t reductionSize = vecType.getDimSize(reductionDim);
    assert(reductionSize > 0 &&
           "Reduction dim must be a known static size to allow unrolling");
    return reductionSize;
  }

  /// Two outer parallel, one inner reduction (matmat flavor).
  FailureOr<Value> matmat() {
    if (!iters({Par(), Par(), Red()}))
      return failure();
    // Set up the parallel/reduction structure in the right form.
    AffineExpr m, n, k;
    bindDims(rewriter.getContext(), m, n, k);

    // Classical row-major matmul: Just permute the lhs.
    if (layout({{m, k}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        // Note: `t` creates new IR. It must be nested within this `if` check
        // so that no IR is created when the pattern returns "failure".
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tLhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // No need to permute anything.
    if (layout({{k, m}, {k, n}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Just permute the rhs.
    if (layout({{k, m}, {n, k}, {m, n}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(lhs, tRhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Transposed output: swap RHS and LHS.
    // Classical row-major matmul: permute the lhs.
    if (layout({{m, k}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // TODO: may be better to fail and use some vector<k> -> scalar reduction.
    if (layout({{m, k}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tRhs = t(rhs);
        Value tLhs = t(lhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, tLhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {k, n}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    if (layout({{k, m}, {n, k}, {n, m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask, {2, 0, 1});
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer parallel, one inner reduction (matvec flavor).
  // Mask needs to be transposed everywhere to turn the reduction dimension
  // outermost, as required by outerproduct.
  //
  FailureOr<Value> matvec() {
    if (!iters({Par(), Red()}))
      return failure();
    AffineExpr m, k;
    bindDims(rewriter.getContext(), m, k);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 1)) {
        Value tLhs = t(lhs);
        Value tMask = t(mask);
        return outerProd(tLhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tRhs = t(rhs);
        Value tMask = t(mask);
        return outerProd(tRhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}})) {
      if (auto reductionSize = getReductionSize(lhsType, 0)) {
        Value tMask = t(mask);
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, tMask);
      }
    }
    return failure();
  }

  //
  // One outer reduction, one inner parallel (tmatvec flavor).
  // Mask already has the shape of the outer product.
  //
  FailureOr<Value> tmatvec() {
    if (!iters({Red(), Par()}))
      return failure();
    AffineExpr k, m;
    bindDims(rewriter.getContext(), k, m);

    // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 1))
        return outerProd(t(lhs), rhs, res, lhsType, *reductionSize, mask);
    // Case mat-trans-vec: ready to go.
    if (layout({{k, m}, {k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(lhs, rhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat: swap and transpose.
    if (layout({{k}, {m, k}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(t(rhs), lhs, res, lhsType, *reductionSize, mask);
    // Case vec-mat-trans: swap and ready to go.
    if (layout({{k}, {k, m}, {m}}))
      if (auto reductionSize = getReductionSize(lhsType, 0))
        return outerProd(rhs, lhs, res, lhsType, *reductionSize, mask);
    return failure();
  }

private:
  vector::CombiningKind kind;
  Value lhs, rhs, res, mask;
  VectorType lhsType;
};

/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
/// semantics to a reduction_size-unrolled sequence:
/// ```
///   %at = vector.transpose %a, [1, 0]
///   %bRow0 = vector.extract %b[0]
///   %atRow0 = vector.extract %at[0]
///   %c0 = vector.outerproduct %atRow0, %bRow0, %c
///   ...
///   %bRowK = vector.extract %b[K]
///   %atRowK = vector.extract %at[K]
///   %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
/// ```
///
/// This only kicks in when vectorContractLowering is set to OuterProduct but
/// otherwise supports any layout permutation of the matrix-multiply.
FailureOr<Value>
ContractionOpToOuterProductOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (vectorContractLowering != vector::VectorContractLowering::OuterProduct)
    return failure();

  if (failed(filter(op)))
    return failure();

  UnrolledOuterProductGenerator e(rewriter, op);
  FailureOr<Value> matmatRes = e.matmat();
  if (succeeded(matmatRes)) {
    return matmatRes;
  }
  FailureOr<Value> matvecRes = e.matvec();
  if (succeeded(matvecRes)) {
    return matvecRes;
  }

  FailureOr<Value> tmatvecRes = e.tmatvec();
  return tmatvecRes;
}

FailureOr<Value> ContractionOpToDotLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  // TODO: Support vector.mask.
  if (maskOp)
    return failure();

  if (failed(filter(op)))
    return failure();

  if (vectorContractLowering != vector::VectorContractLowering::Dot)
    return failure();

  auto iteratorTypes = op.getIteratorTypes().getValue();
  static constexpr std::array<int64_t, 2> perm = {1, 0};
  Location loc = op.getLoc();
  Value lhs = op.getLhs(), rhs = op.getRhs();

  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
  auto infer = [&](MapList m) {
    return AffineMap::inferFromExprList(m, op.getContext());
  };
  AffineExpr m, n, k;
  bindDims(rewriter.getContext(), m, n, k);
  SmallVector<AffineMap> maps = op.getIndexingMapsArray();
  //
  // In the following we wish to make the reduction dimension innermost so we
  // can load vectors and just fmul + reduce into a scalar.
  //
  if (isParallelIterator(iteratorTypes[0]) &&
      isParallelIterator(iteratorTypes[1]) &&
      isReductionIterator(iteratorTypes[2])) {
    //
    // Two outer parallel, one inner reduction (matmat flavor).
    //
    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
      // No need to permute anything.
    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
      rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
      // Transposed output: swap the operands; the new lhs is the permuted rhs.
      Value tmp = lhs;
      lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
      rhs = tmp;
    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
      Value tmp = lhs;
      lhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
      rhs = vector::TransposeOp::create(rewriter, loc, tmp, perm);
    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
      Value tmp = rhs;
      rhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
      lhs = tmp;
    } else {
      return failure();
    }
  } else if (isParallelIterator(iteratorTypes[0]) &&
             isReductionIterator(iteratorTypes[1])) {
    //
    // One outer parallel, one inner reduction (matvec flavor).
    //
    if (maps == infer({{m, n}, {n}, {m}})) {
      // No need to permute anything.
    } else if (maps == infer({{n, m}, {n}, {m}})) {
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else if (maps == infer({{n}, {m, n}, {m}})) {
      std::swap(lhs, rhs);
    } else if (maps == infer({{n}, {n, m}, {m}})) {
      std::swap(lhs, rhs);
      lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
    } else {
      return failure();
    }
  } else {
    return failure();
  }

  VectorType dstType = cast<VectorType>(op.getResultType());
  assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 &&
         "Expected dst type of rank 1 or 2");

  unsigned rank = dstType.getRank();
  unsigned dstRows = dstType.getShape()[0];
  unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1];

  // ExtractOp does not allow dynamic indexing; we must unroll explicitly.
  Value res = arith::ConstantOp::create(rewriter, loc, dstType,
                                        rewriter.getZeroAttr(dstType));
  bool isInt = isa<IntegerType>(dstType.getElementType());
  llvm::SmallVector<Value> extractedCols;
  extractedCols.reserve(dstColumns);
  for (unsigned r = 0; r < dstRows; ++r) {
    Value rowLhs = vector::ExtractOp::create(rewriter, op.getLoc(), lhs, r);
    for (unsigned c = 0; c < dstColumns; ++c) {
      // Extract each respective row and column of the LHS and RHS once to
      // avoid having duplicate SSA values pointing to the same rows/columns.
      if (r == 0) {
        Value colRhs =
            rank == 1
                ? rhs
                : vector::ExtractOp::create(rewriter, op.getLoc(), rhs, c);
        extractedCols.push_back(colRhs);
      }
      Value extractedColRhs = extractedCols[c];
      Value product =
          createMul(op.getLoc(), rowLhs, extractedColRhs, isInt, rewriter);
      Value sum = vector::ReductionOp::create(
          rewriter, op.getLoc(), vector::CombiningKind::ADD, product);

      SmallVector<int64_t> pos =
          rank == 1 ? SmallVector<int64_t>{r} : SmallVector<int64_t>{r, c};
      res = vector::InsertOp::create(rewriter, op.getLoc(), sum, res, pos);
    }
  }
  if (auto acc = op.getAcc())
    res = createAdd(op.getLoc(), res, acc, isInt, rewriter);
  return res;
}

/// Lower vector.contract with all size one reduction dimensions to
/// elementwise ops when possible.
class ContractOpToElementwise
    : public MaskableOpRewritePattern<vector::ContractionOp> {
public:
  using MaskableOpRewritePattern::MaskableOpRewritePattern;
  using FilterConstraintType =
      std::function<LogicalResult(vector::ContractionOp op)>;
  static LogicalResult defaultFilter(vector::ContractionOp op) {
    return success();
  }
  ContractOpToElementwise(
      vector::VectorContractLowering vectorContractLowering,
      MLIRContext *context, PatternBenefit benefit = 1,
      const FilterConstraintType &constraint = defaultFilter)
      : MaskableOpRewritePattern<vector::ContractionOp>(context, benefit),
        vectorContractLowering(vectorContractLowering), filter(constraint) {}

  FailureOr<Value>
  matchAndRewriteMaskableOp(vector::ContractionOp contractOp,
                            MaskingOpInterface maskOp,
                            PatternRewriter &rewriter) const override {
    // TODO: Support vector.mask.
    if (maskOp)
      return failure();

    if (failed(filter(contractOp)))
      return failure();

    if (vectorContractLowering != vector::VectorContractLowering::ParallelArith)
      return failure();

    ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
    ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
    AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
    AffineMap rhsMap = contractOp.getIndexingMapsArray()[1];
    SmallVector<int64_t> lhsReductionDims =
        getReductionIndex(lhsMap, contractOp.getIteratorTypes());
    SmallVector<int64_t> rhsReductionDims =
        getReductionIndex(rhsMap, contractOp.getIteratorTypes());
    // All the reduction dimensions must be of size 1.
    for (int64_t dim : lhsReductionDims) {
      if (lhsShape[dim] != 1)
        return failure();
    }
    for (int64_t dim : rhsReductionDims) {
      if (rhsShape[dim] != 1)
        return failure();
    }
    AffineMap accMap = contractOp.getIndexingMapsArray()[2];
    unsigned numParallelDims = accMap.getNumResults();
    unsigned numLhsDimToBroadcast =
        numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size());
    unsigned numRhsDimToBroadcast =
        numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size());
    SmallVector<int64_t> lhsDims;
    SmallVector<int64_t> lhsTranspose;
    SmallVector<int64_t> rhsDims;
    SmallVector<int64_t> rhsTranspose;
    for (int64_t dim : lhsReductionDims)
      lhsTranspose.push_back(numLhsDimToBroadcast + dim);
    for (int64_t dim : rhsReductionDims)
      rhsTranspose.push_back(numRhsDimToBroadcast + dim);
    // Loop through the parallel dimensions to calculate the dimensions to
    // broadcast and to permute in order to extract only parallel dimensions.
    for (unsigned i = 0; i < numParallelDims; i++) {
      std::optional<unsigned> lhsDim =
          getDimPosition(lhsMap, accMap.getDimPosition(i));
      if (lhsDim) {
        lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        lhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        lhsTranspose.push_back(lhsDims.size() - 1);
      }
      std::optional<unsigned> rhsDim =
          getDimPosition(rhsMap, accMap.getDimPosition(i));
      if (rhsDim) {
        rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim);
      } else {
        // If the parallel dimension doesn't exist we will have to broadcast it.
        rhsDims.push_back(
            cast<VectorType>(contractOp.getResultType()).getDimSize(i));
        rhsTranspose.push_back(rhsDims.size() - 1);
      }
    }
    Value newLhs = contractOp.getLhs();
    Value newRhs = contractOp.getRhs();
    Location loc = contractOp.getLoc();
    if (!lhsDims.empty()) {
      lhsDims.append(lhsShape.begin(), lhsShape.end());
      auto expandedType =
          VectorType::get(lhsDims, contractOp.getLhsType().getElementType());
      newLhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newLhs);
    }
    if (!rhsDims.empty()) {
      rhsDims.append(rhsShape.begin(), rhsShape.end());
      auto expandedType =
          VectorType::get(rhsDims, contractOp.getRhsType().getElementType());
      newRhs = vector::BroadcastOp::create(rewriter, loc, expandedType, newRhs);
    }
    bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex();
    newLhs = vector::TransposeOp::create(rewriter, loc, newLhs, lhsTranspose);
    newRhs = vector::TransposeOp::create(rewriter, loc, newRhs, rhsTranspose);
    SmallVector<int64_t> lhsOffsets(lhsReductionDims.size(), 0);
    SmallVector<int64_t> rhsOffsets(rhsReductionDims.size(), 0);
    newLhs = vector::ExtractOp::create(rewriter, loc, newLhs, lhsOffsets);
    newRhs = vector::ExtractOp::create(rewriter, loc, newRhs, rhsOffsets);
    std::optional<Value> result =
        createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(),
                              contractOp.getKind(), rewriter, isInt);
    if (result)
      return *result;

    return failure();
  }

private:
  /// Options to control the vector patterns.
  vector::VectorContractLowering vectorContractLowering;
  FilterConstraintType filter;
};

/// Progressive lowering of ContractionOp.
/// One:
///   %x = vector.contract with at least one free/batch dimension
/// is replaced by:
///   %a = vector.contract with one less free/batch dimension
///   %b = vector.contract with one less free/batch dimension
///   ..
///   %x = combine %a %b ..
/// until a pure contraction is reached (no free/batch dimensions),
/// which is replaced by a dot-product.
///
/// This only kicks in when either vectorContractLoweringOption is set
/// to Dot or when other contraction patterns fail.
//
// TODO: break down into transpose/reshape/cast ops
//       when they become available to avoid code dup
// TODO: investigate lowering order impact on performance
FailureOr<Value> ContractionOpLowering::matchAndRewriteMaskableOp(
    vector::ContractionOp op, MaskingOpInterface maskOp,
    PatternRewriter &rewriter) const {
  if (failed(filter(op)))
    return failure();

  // TODO: support mixed mode contract lowering.
  if (op.getLhsType().getElementType() !=
          getElementTypeOrSelf(op.getAccType()) ||
      op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType()))
    return failure();

  // TODO: the code below assumes the default contraction; make sure it
  // supports other kinds before enabling this lowering.
  if (op.getKind() != vector::CombiningKind::ADD) {
    return rewriter.notifyMatchFailure(
        op, "contractions other than 'add' not supported");
  }

  // TODO: implement benefits, cost models.
  MLIRContext *ctx = op.getContext();

  ContractionOpToOuterProductOpLowering pat1(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal1 =
      pat1.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal1))
    return newVal1;

  ContractionOpToDotLowering pat2(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal2 =
      pat2.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal2))
    return newVal2;

  ContractOpToElementwise pat4(vectorContractLoweringOption, ctx);
  FailureOr<Value> newVal4 =
      pat4.matchAndRewriteMaskableOp(op, maskOp, rewriter);
  if (!failed(newVal4))
    return newVal4;

  // Vector mask setup.
  Value mask;
  if (maskOp)
    mask = maskOp.getMask();
  // Find first batch dimension in LHS/RHS, and lower when found.
  std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
  if (!batchDimMap.empty()) {
    int64_t lhsIndex = batchDimMap[0].first;
    int64_t rhsIndex = batchDimMap[0].second;
    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
    if (failed(newOp))
      return failure();
    return newOp;
  }

  // Collect contracting dimensions.
  std::vector<std::pair<int64_t, int64_t>> contractingDimMap =
      op.getContractingDimMap();
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }

  // Find first free dimension in LHS, and lower when found.
  VectorType lhsType = op.getLhsType();
  for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
    if (lhsContractingDimSet.count(lhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
      if (failed(newOp))
        return failure();
      return newOp;
    }
  }

  // Find first free dimension in RHS, and lower when found.
  VectorType rhsType = op.getRhsType();
  for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
    if (rhsContractingDimSet.count(rhsIndex) == 0) {
      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
      if (failed(newOp))
        return failure();
      return newOp;
    }
  }

  // Lower the first remaining reduction dimension.
  if (!contractingDimMap.empty()) {
    auto newOp = lowerReduction(rewriter, op, mask);
    if (failed(newOp))
      return failure();
    return newOp;
  }

  return failure();
}

// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
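// For example, unrolling a parallel dimension of size 4 emits four
// lower-rank vector.contract ops, one per slice of the operands along that
// dimension; each slice result is written back into the full result vector
// via reshapeStore.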
FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
                                                      vector::ContractionOp op,
                                                      int64_t lhsIndex,
                                                      int64_t rhsIndex,
                                                      Value mask) const {
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  VectorType resType = cast<VectorType>(op.getResultType());
  // Find the iterator type index and result index.
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  int64_t iterIndex = -1;
  int64_t dimSize = -1;
  if (lhsIndex >= 0) {
    iterIndex = iMap[0].getDimPosition(lhsIndex);
    if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex))
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
             << " to map to the same dimension";
      });
    if (lhsType.getScalableDims()[lhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (lhsIndex=" << lhsIndex
             << ") is not supported yet";
      });
    dimSize = lhsType.getDimSize(lhsIndex);
  } else if (rhsIndex >= 0) {
    iterIndex = iMap[1].getDimPosition(rhsIndex);
    if (rhsType.getScalableDims()[rhsIndex])
      return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
        diag << "Unrolling scalable dimension (rhsIndex=" << rhsIndex
             << ") is not supported yet";
      });
    dimSize = rhsType.getDimSize(rhsIndex);
  }
  if (iterIndex < 0)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected either lhsIndex=" << lhsIndex
           << " or rhsIndex=" << rhsIndex << " to be nonnegative";
    });
  // value_or(-1) means that we tolerate a dimension not appearing
  // in the result map. That can't happen for actual parallel iterators, but
  // the caller ContractionOpLowering::matchAndRewrite is currently calling
  // lowerParallel also for the case of unit-size reduction dims appearing only
  // on one of LHS or RHS, not both. At the moment, such cases are created by
  // CastAwayContractionLeadingOneDim, so we need to either support that or
  // modify that pattern.
  int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1);
  if (resIndex == -1 && dimSize != 1)
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected the dimension for iterIndex=" << iterIndex
           << " to either appear in the result map, or to be a unit dimension";
    });

  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  Location loc = op.getLoc();
  Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                           rewriter.getZeroAttr(resType));

  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);

    Value lowMask;
    if (mask)
      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *lowContract = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, acc, lowAffine, lowIter);
    lowContract = maskOperation(rewriter, lowContract, lowMask);
    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
                          resIndex, d, rewriter);
  }
  return result;
}

// Lower one reduction dimension.
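// For example, reducing a dimension of size 4 chains four lower-rank
// vector.contract ops, threading the accumulator through each so the last
// op produces the complete sum.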
FailureOr<Value> ContractionOpLowering::lowerReduction(
    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
  auto loc = op.getLoc();
  VectorType lhsType = op.getLhsType();
  VectorType rhsType = op.getRhsType();
  Type resType = op.getResultType();
  if (isa<VectorType>(resType))
    return rewriter.notifyMatchFailure(op,
                                       "did not expect a VectorType result");
  bool isInt = isa<IntegerType>(resType);
  // Use iterator index 0.
  int64_t iterIndex = 0;
  SmallVector<AffineMap> iMap = op.getIndexingMapsArray();
  std::optional<int64_t> lookupLhs = getResultIndex(iMap[0], iterIndex);
  std::optional<int64_t> lookupRhs = getResultIndex(iMap[1], iterIndex);
  if (!lookupLhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a LHS dimension";
    });
  if (!lookupRhs.has_value())
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expected iterIndex=" << iterIndex
           << " to map to a RHS dimension";
    });
  int64_t lhsIndex = *lookupLhs;
  int64_t rhsIndex = *lookupRhs;
  int64_t dimSize = lhsType.getDimSize(lhsIndex);
  if (dimSize != rhsType.getDimSize(rhsIndex))
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << "expect LHS dimension " << lhsIndex
           << " to have the same size as RHS dimension " << rhsIndex;
    });
  // Base case.
  if (lhsType.getRank() == 1) {
    if (rhsType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          op, "when LHS has rank 1, expected RHS to also have rank 1");
    Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
    auto kind = vector::CombiningKind::ADD;

    Value acc = op.getAcc();
    Operation *reductionOp =
        acc ? vector::ReductionOp::create(rewriter, loc, kind, m, acc)
            : vector::ReductionOp::create(rewriter, loc, kind, m);
    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
  }
  // Construct new iterator types and affine map array attribute.
  std::array<AffineMap, 3> lowIndexingMaps = {
      adjustMap(iMap[0], iterIndex, rewriter),
      adjustMap(iMap[1], iterIndex, rewriter),
      adjustMap(iMap[2], iterIndex, rewriter)};
  auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps);
  auto lowIter =
      rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex));
  // Unroll into a series of lower dimensional vector.contract ops.
  // By feeding the initial accumulator into the first contraction,
  // and the result of each contraction into the next, eventually
  // the sum of all reductions is computed.
  Value result = op.getAcc();
  for (int64_t d = 0; d < dimSize; ++d) {
    auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
    auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
    Value newMask;
    if (mask)
      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
                            iterIndex, d, rewriter);

    Operation *newContract = vector::ContractionOp::create(
        rewriter, loc, lhs, rhs, result, lowAffine, lowIter);
    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
  }
  return result;
}

/// Progressive lowering of OuterProductOp.
/// One:
///   %x = vector.outerproduct %lhs, %rhs, %acc
/// is replaced by:
///   %z = zero vector of the result type
///   %0 = vector.extract %lhs[0]
///   %1 = vector.broadcast %0
///   %2 = vector.extract %acc[0]
///   %3 = vector.fma %1, %rhs, %2
///   %4 = vector.insert %3, %z[0]
///   ..
///   %x = vector.insert %.., %..[N-1]
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
  using Base::Base;

  LogicalResult matchAndRewrite(vector::OuterProductOp op,
                                PatternRewriter &rewriter) const override {
    VectorType resType = op.getResultVectorType();
    if ((resType.getShape().size() >= 2) && resType.allDimsScalable())
      return failure();

    auto loc = op.getLoc();

    VectorType lhsType = op.getOperandVectorTypeLHS();
    VectorType rhsType = dyn_cast<VectorType>(op.getOperandTypeRHS());
    Type eltType = resType.getElementType();
    bool isInt = isa<IntegerType, IndexType>(eltType);
    Value acc = op.getAcc();
    vector::CombiningKind kind = op.getKind();

    // Vector mask setup.
    OpBuilder::InsertionGuard guard(rewriter);
    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
    Operation *rootOp;
    Value mask;
    if (maskableOp.isMasked()) {
      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
      rootOp = maskableOp.getMaskingOp();
      mask = maskableOp.getMaskingOp().getMask();
    } else {
      rootOp = op;
    }

    if (!rhsType) {
      // Special case: AXPY operation.
      Value b =
          vector::BroadcastOp::create(rewriter, loc, lhsType, op.getRhs());
      std::optional<Value> mult = createContractArithOp(
          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
      if (!mult.has_value())
        return failure();
      rewriter.replaceOp(rootOp, *mult);
      return success();
    }

    Value result = arith::ConstantOp::create(rewriter, loc, resType,
                                             rewriter.getZeroAttr(resType));
    for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
      Value x = vector::ExtractOp::create(rewriter, loc, op.getLhs(), d);
      Value a = vector::BroadcastOp::create(rewriter, loc, rhsType, x);
      Value r = nullptr;
      if (acc)
        r = vector::ExtractOp::create(rewriter, loc, acc, d);
      Value extrMask;
      if (mask)
        extrMask = vector::ExtractOp::create(rewriter, loc, mask, d);

      std::optional<Value> m = createContractArithOp(
          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
      if (!m.has_value())
        return failure();
      result = vector::InsertOp::create(rewriter, loc, *m, result, d);
    }

    rewriter.replaceOp(rootOp, result);
    return success();
  }
};

} // namespace

void mlir::vector::populateVectorContractLoweringPatterns(
    RewritePatternSet &patterns,
    VectorContractLowering vectorContractLoweringOption, PatternBenefit benefit,
    bool disableOuterProductLowering) {
  if (!disableOuterProductLowering)
    patterns.add<OuterProductOpLowering>(patterns.getContext(), benefit);
  patterns.add<ContractionOpLowering, ContractionOpToOuterProductOpLowering>(
      vectorContractLoweringOption, patterns.getContext(), benefit);
}
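
// Example usage (a sketch): a pass would typically hand these patterns to
// the greedy rewrite driver, e.g.:
//
//   RewritePatternSet patterns(&getContext());
//   vector::populateVectorContractLoweringPatterns(
//       patterns, vector::VectorContractLowering::OuterProduct);
//   if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
//     signalPassFailure();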