MLIR 23.0.0git
LowerVectorInterleave.cpp
Go to the documentation of this file.
1//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.interleave' operation.
11//
12//===----------------------------------------------------------------------===//
13
19#include "llvm/ADT/SmallVectorExtras.h"
20
21#define DEBUG_TYPE "vector-interleave-lowering"
22
23using namespace mlir;
24using namespace mlir::vector;
25
26namespace {
27
28/// A one-shot unrolling of vector.interleave to the `targetRank`.
29///
30/// Example:
31///
32/// ```mlir
33/// vector.interleave %a, %b : vector<1x2x3x4xi64> -> vector<1x2x3x8xi64>
34/// ```
35/// Would be unrolled to:
36/// ```mlir
37/// %result = arith.constant dense<0> : vector<1x2x3x8xi64>
38/// %0 = vector.extract %a[0, 0, 0] ─┐
39/// : vector<4xi64> from vector<1x2x3x4xi64> |
40/// %1 = vector.extract %b[0, 0, 0] |
41/// : vector<4xi64> from vector<1x2x3x4xi64> | - Repeated 6x for
42/// %2 = vector.interleave %0, %1 : | all leading positions
43/// : vector<4xi64> -> vector<8xi64> |
44/// %3 = vector.insert %2, %result [0, 0, 0] |
45/// : vector<8xi64> into vector<1x2x3x8xi64> ┘
46/// ```
47///
48/// Note: If any leading dimension before the `targetRank` is scalable the
49/// unrolling will stop before the scalable dimension.
50class UnrollInterleaveOp final : public OpRewritePattern<vector::InterleaveOp> {
51public:
52 UnrollInterleaveOp(int64_t targetRank, MLIRContext *context,
53 PatternBenefit benefit = 1)
54 : OpRewritePattern(context, benefit), targetRank(targetRank){};
55
56 LogicalResult matchAndRewrite(vector::InterleaveOp op,
57 PatternRewriter &rewriter) const override {
58 VectorType resultType = op.getResultVectorType();
59 auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
60 if (!unrollIterator)
61 return failure();
62
63 auto loc = op.getLoc();
64 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
65 rewriter.getZeroAttr(resultType));
66 for (auto position : *unrollIterator) {
67 Value extractLhs =
68 ExtractOp::create(rewriter, loc, op.getLhs(), position);
69 Value extractRhs =
70 ExtractOp::create(rewriter, loc, op.getRhs(), position);
71 Value interleave =
72 InterleaveOp::create(rewriter, loc, extractLhs, extractRhs);
73 result = InsertOp::create(rewriter, loc, interleave, result, position);
74 }
75
76 rewriter.replaceOp(op, result);
77 return success();
78 }
79
80private:
81 int64_t targetRank = 1;
82};
83
84/// A one-shot unrolling of vector.deinterleave to the `targetRank`.
85///
86/// Example:
87///
88/// ```mlir
89/// %0, %1 = vector.deinterleave %a : vector<1x2x3x8xi64> -> vector<1x2x3x4xi64>
90/// ```
91/// Would be unrolled to:
92/// ```mlir
93/// %result = arith.constant dense<0> : vector<1x2x3x4xi64>
94/// %0 = vector.extract %a[0, 0, 0] ─┐
95/// : vector<8xi64> from vector<1x2x3x8xi64> |
96/// %1, %2 = vector.deinterleave %0 |
97/// : vector<8xi64> -> vector<4xi64> | -- Initial deinterleave
98/// %3 = vector.insert %1, %result [0, 0, 0] | operation unrolled.
99/// : vector<4xi64> into vector<1x2x3x4xi64> |
100/// %4 = vector.insert %2, %result [0, 0, 0] |
101/// : vector<4xi64> into vector<1x2x3x4xi64> ┘
102/// %5 = vector.extract %a[0, 0, 1] ─┐
103/// : vector<8xi64> from vector<1x2x3x8xi64> |
104/// %6, %7 = vector.deinterleave %5 |
105/// : vector<8xi64> -> vector<4xi64> | -- Recursive pattern for
106/// %8 = vector.insert %6, %3 [0, 0, 1] | subsequent unrolled
107/// : vector<4xi64> into vector<1x2x3x4xi64> | deinterleave
108/// %9 = vector.insert %7, %4 [0, 0, 1] | operations. Repeated
109/// : vector<4xi64> into vector<1x2x3x4xi64> ┘ 5x in this case.
110/// ```
111///
112/// Note: If any leading dimension before the `targetRank` is scalable the
113/// unrolling will stop before the scalable dimension.
114class UnrollDeinterleaveOp final
115 : public OpRewritePattern<vector::DeinterleaveOp> {
116public:
117 UnrollDeinterleaveOp(int64_t targetRank, MLIRContext *context,
118 PatternBenefit benefit = 1)
119 : OpRewritePattern(context, benefit), targetRank(targetRank) {};
120
121 LogicalResult matchAndRewrite(vector::DeinterleaveOp op,
122 PatternRewriter &rewriter) const override {
123 VectorType resultType = op.getResultVectorType();
124 auto unrollIterator = vector::createUnrollIterator(resultType, targetRank);
125 if (!unrollIterator)
126 return failure();
127
128 auto loc = op.getLoc();
129 Value emptyResult = arith::ConstantOp::create(
130 rewriter, loc, resultType, rewriter.getZeroAttr(resultType));
131 Value evenResult = emptyResult;
132 Value oddResult = emptyResult;
133
134 for (auto position : *unrollIterator) {
135 auto extractSrc =
136 vector::ExtractOp::create(rewriter, loc, op.getSource(), position);
137 auto deinterleave =
138 vector::DeinterleaveOp::create(rewriter, loc, extractSrc);
139 evenResult = vector::InsertOp::create(
140 rewriter, loc, deinterleave.getRes1(), evenResult, position);
141 oddResult = vector::InsertOp::create(
142 rewriter, loc, deinterleave.getRes2(), oddResult, position);
143 }
144 rewriter.replaceOp(op, ValueRange{evenResult, oddResult});
145 return success();
146 }
147
148private:
149 int64_t targetRank = 1;
150};
151/// Rewrite vector.interleave op into an equivalent vector.shuffle op, when
152/// applicable: `sourceType` must be 1D and non-scalable.
153///
154/// Example:
155///
156/// ```mlir
157/// vector.interleave %a, %b : vector<7xi16> -> vector<14xi16>
158/// ```
159///
160/// Is rewritten into:
161///
162/// ```mlir
163/// vector.shuffle %arg0, %arg1 [0, 7, 1, 8, 2, 9, 3, 10, 4, 11, 5, 12, 6, 13]
164/// : vector<7xi16>, vector<7xi16>
165/// ```
166struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
167 using Base::Base;
168
169 LogicalResult matchAndRewrite(vector::InterleaveOp op,
170 PatternRewriter &rewriter) const override {
171 VectorType sourceType = op.getSourceVectorType();
172 if (sourceType.getRank() != 1 || sourceType.isScalable()) {
173 return failure();
174 }
175 int64_t n = sourceType.getNumElements();
176 auto seq = llvm::seq<int64_t>(2 * n);
177 auto zip = llvm::map_to_vector(
178 seq, [n](int64_t i) { return (i % 2 ? n : 0) + i / 2; });
179 rewriter.replaceOpWithNewOp<ShuffleOp>(op, op.getLhs(), op.getRhs(), zip);
180 return success();
181 }
182};
183
184} // namespace
185
187 RewritePatternSet &patterns, int64_t targetRank, PatternBenefit benefit) {
// Register both unrolling patterns with the same `targetRank` and `benefit`,
// so vector.interleave and vector.deinterleave unroll to the same rank.
188 patterns.add<UnrollInterleaveOp, UnrollDeinterleaveOp>(
189 targetRank, patterns.getContext(), benefit);
190}
191
// Register the vector.interleave -> vector.shuffle rewrite; the pattern itself
// only matches 1-D, non-scalable operands.
194 patterns.add<InterleaveToShuffle>(patterns.getContext(), benefit);
195}
return success()
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:328
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
void populateVectorInterleaveToShufflePatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
std::optional< StaticTileOffsetRange > createUnrollIterator(VectorType vType, int64_t targetRank=1)
Returns an iterator for all positions in the leading dimensions of vType up to the targetRank.
void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns, int64_t targetRank=1, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...