MLIR 22.0.0git
LowerVectorInterleave.cpp
Go to the documentation of this file.
//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements target-independent rewrites and utilities to lower the
// 'vector.interleave' and 'vector.deinterleave' operations.
//
//===----------------------------------------------------------------------===//
13
19
20#define DEBUG_TYPE "vector-interleave-lowering"
21
22using namespace mlir;
23using namespace mlir::vector;
24
25namespace {
26
27/// A one-shot unrolling of vector.interleave to the `targetRank`.
28///
29/// Example:
30///
31/// ```mlir
32/// vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
33/// ```
34/// Would be unrolled to:
35/// ```mlir
36/// %result = arith.constant dense<0> : vector<1x2x3x8xi64>
37/// %0 = vector.extract %a[0, 0, 0] ─┐
38/// : vector<4xi64> from vector<1x2x3x4xi64> |
39/// %1 = vector.extract %b[0, 0, 0] |
40/// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
41/// %2 = vector.interleave %0, %1 : | all leading positions
42/// : vector<4xi64> -> vector<8xi64> |
43/// %3 = vector.insert %2, %result [0, 0, 0] |
44/// : vector<8xi64> into vector<1x2x3x8xi64> ┘
45/// ```
46///
47/// Note: If any leading dimension before the `targetRank` is scalable the
48/// unrolling will stop before the scalable dimension.
49class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
50public:
51 UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
52 PatternBenefit benefit = 1)
53 : OpRewritePattern(context, benefit), targetRank(targetRank){};
54
55 LogicalResult matchAndRewrite(vector::InterleaveOp op,
56 PatternRewriter &rewriter) const override {
57 VectorType resultType = op.getResultVectorType();
58 auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
59 if (!unrollIterator)
60 return failure();
61
62 auto loc = op.getLoc();
63 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
64 rewriter.getZeroAttr(resultType));
65 for (auto position : *unrollIterator) {
66 Value extractLhs =
67 ExtractOp::create(rewriter, loc, op.getLhs(), position);
68 Value extractRhs =
69 ExtractOp::create(rewriter, loc, op.getRhs(), position);
70 Value interleave =
71 InterleaveOp::create(rewriter, loc, extractLhs, extractRhs);
72 result = InsertOp::create(rewriter, loc, interleave, result, position);
73 }
74
75 rewriter.replaceOp(op, result);
76 return success();
77 }
78
79private:
80 int64_t targetRank = 1;
81};
82
83/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
84///
85/// Example:
86///
87/// ```mlir
88/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
89/// ```
90/// Would be unrolled to:
91/// ```mlir
92/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
93/// %0 = vector.extract %a[0, 0, 0] ─┐
94/// : vector<8xi64> from vector<1x2x3x8xi64> |
95/// %1, %2 = vector.deinterleave %0 |
96/// : vector<8xi64> -> vector<4xi64> | -- Initial deinterleave
97/// %3 = vector.insert %1, %result [0, 0, 0] | operation unrolled.
98/// : vector<4xi64> into vector<1x2x3x4xi64> |
99/// %4 = vector.insert %2, %result [0, 0, 0] |
100/// : vector<4xi64> into vector<1x2x3x4xi64> ┘
101/// %5 = vector.extract %a[0, 0, 1] ─┐
102/// : vector<8xi64> from vector<1x2x3x8xi64> |
103/// %6, %7 = vector.deinterleave %5 |
104/// : vector<8xi64> -> vector<4xi64> | -- Recursive pattern for
105/// %8 = vector.insert %6, %3 [0, 0, 1] | subsequent unrolled
106/// : vector<4xi64> into vector<1x2x3x4xi64> | deinterleave
107/// %9 = vector.insert %7, %4 [0, 0, 1] | operations. Repeated
108/// : vector<4xi64> into vector<1x2x3x4xi64> ┘ 5x in this case.
109/// ```
110///
111/// Note: If any leading dimension before the `targetRank` is scalable the
112/// unrolling will stop before the scalable dimension.
113class UnrollDeinterleaveOp final
114 : public OpRewritePattern<vector::DeinterleaveOp> {
115public:
116 UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
117 PatternBenefit benefit = 1)
118 : OpRewritePattern(context, benefit), targetRank(targetRank) {};
119
120 LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
121 PatternRewriter &rewriter) const override {
122 VectorType resultType = op.getResultVectorType();
123 auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
124 if (!unrollIterator)
125 return failure();
126
127 auto loc = op.getLoc();
128 Value emptyResult = arith::ConstantOp::create(
129 rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
130 Value evenResult = emptyResult;
131 Value oddResult = emptyResult;
132
133 for (auto position : *unrollIterator) {
134 auto extractSrc =
135 vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
136 auto deinterleave =
137 vector::DeinterleaveOp::create(rewriter, loc, extractSrc);
138 evenResult = vector::InsertOp::create(
139 rewriter, loc, deinterleave.getRes1(), evenResult, position);
140 oddResult = vector::InsertOp::create(
141 rewriter, loc, deinterleave.getRes2(), oddResult, position);
142 }
143 rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
144 return success();
145 }
146
147private:
148 int64_t targetRank = 1;
149};
150/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
151/// applicable: `sourceType` must be 1D and non-scalable.
152///
153/// Example:
154///
155/// ```mlir
156/// vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
157/// ```
158///
159/// Is rewritten into:
160///
161/// ```mlir
162/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
163/// : vector<7xi16>, vector<7xi16>
164/// ```
165struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
166 using Base::Base;
167
168 LogicalResult matchAndRewrite(vector::InterleaveOp op,
169 PatternRewriter &rewriter) const override {
170 VectorType sourceType = op.getSourceVectorType();
171 if (sourceType.getRank() != 1 || sourceType.isScalable()) {
172 return failure();
173 }
174 int64_t n = sourceType.getNumElements();
175 auto seq = llvm::seq<int64_t>(2 * n);
176 auto zip = llvm::to_vector(llvm::map_range(
177 seq, [n](int64_t i) { return (i % 2 ? n : 0) + i / 2; }));
178 rewriter.replaceOpWithNewOp<ShuffleOp>(op, op.getLhs(), op.getRhs(), zip);
179 return success();
180 }
181};
182
183} // namespace
184
186 RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
187 patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
188 targetRank, patterns.getContext(), benefit);
189}
190
193 patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
194}
return success()
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...