MLIR 22.0.0git
LowerVectorScan.cpp
Go to the documentation of this file.
1//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements target-independent rewrites and utilities to lower the
10// 'vector.scan' operation.
11//
12//===----------------------------------------------------------------------===//
13
21#include "mlir/IR/Location.h"
24
25#define DEBUG_TYPE "vector-broadcast-lowering"
26
27using namespace mlir;
28using namespace mlir::vector;
29
30/// This function checks to see if the vector combining kind
31/// is consistent with the integer or float element type.
32static bool isValidKind(bool isInt, vector::CombiningKind kind) {
33 using vector::CombiningKind;
34 enum class KindType { FLOAT, INT, INVALID };
35 KindType type{KindType::INVALID};
36 switch (kind) {
37 case CombiningKind::MINNUMF:
38 case CombiningKind::MINIMUMF:
39 case CombiningKind::MAXNUMF:
40 case CombiningKind::MAXIMUMF:
41 type = KindType::FLOAT;
42 break;
43 case CombiningKind::MINUI:
44 case CombiningKind::MINSI:
45 case CombiningKind::MAXUI:
46 case CombiningKind::MAXSI:
47 case CombiningKind::AND:
48 case CombiningKind::OR:
49 case CombiningKind::XOR:
50 type = KindType::INT;
51 break;
52 case CombiningKind::ADD:
53 case CombiningKind::MUL:
54 type = isInt ? KindType::INT : KindType::FLOAT;
55 break;
56 }
57 bool isValidIntKind = (type == KindType::INT) && isInt;
58 bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
59 return (isValidIntKind || isValidFloatKind);
60}
61
62namespace {
63/// Convert vector.scan op into arith ops and vector.insert_strided_slice /
64/// vector.extract_strided_slice.
65///
66/// Example:
67///
68/// ```
69/// %0:2 = vector.scan <add>, %arg0, %arg1
70/// {inclusive = true, reduction_dim = 1} :
71/// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
72/// ```
73///
74/// is converted to:
75///
76/// ```
77/// %cst = arith.constant dense<0> : vector<2x3xi32>
78/// %0 = vector.extract_strided_slice %arg0
79/// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
80/// : vector<2x3xi32> to vector<2x1xi32>
81/// %1 = vector.insert_strided_slice %0, %cst
82/// {offsets = [0, 0], strides = [1, 1]}
83/// : vector<2x1xi32> into vector<2x3xi32>
84/// %2 = vector.extract_strided_slice %arg0
85/// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
86/// : vector<2x3xi32> to vector<2x1xi32>
87/// %3 = arith.muli %0, %2 : vector<2x1xi32>
88/// %4 = vector.insert_strided_slice %3, %1
89/// {offsets = [0, 1], strides = [1, 1]}
90/// : vector<2x1xi32> into vector<2x3xi32>
91/// %5 = vector.extract_strided_slice %arg0
92/// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
93/// : vector<2x3xi32> to vector<2x1xi32>
94/// %6 = arith.muli %3, %5 : vector<2x1xi32>
95/// %7 = vector.insert_strided_slice %6, %4
96/// {offsets = [0, 2], strides = [1, 1]}
97/// : vector<2x1xi32> into vector<2x3xi32>
98/// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
99/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
100/// ```
101struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
102 using Base::Base;
103
104 LogicalResult matchAndRewrite(vector::ScanOp scanOp,
105 PatternRewriter &rewriter) const override {
106 auto loc = scanOp.getLoc();
107 VectorType destType = scanOp.getDestType();
108 ArrayRef<int64_t> destShape = destType.getShape();
109 auto elType = destType.getElementType();
110 bool isInt = elType.isIntOrIndex();
111 if (!isValidKind(isInt, scanOp.getKind()))
112 return failure();
113
114 VectorType resType = destType;
115 Value result = arith::ConstantOp::create(rewriter, loc, resType,
116 rewriter.getZeroAttr(resType));
117 int64_t reductionDim = scanOp.getReductionDim();
118 bool inclusive = scanOp.getInclusive();
119 int64_t destRank = destType.getRank();
120 VectorType initialValueType = scanOp.getInitialValueType();
121 int64_t initialValueRank = initialValueType.getRank();
122
123 SmallVector<int64_t> reductionShape(destShape);
124 SmallVector<bool> reductionScalableDims(destType.getScalableDims());
125
126 if (reductionScalableDims[reductionDim])
127 return rewriter.notifyMatchFailure(
128 scanOp, "Trying to reduce scalable dimension - not yet supported!");
129
130 // The reduction dimension, after reducing, becomes 1. It's a fixed-width
131 // dimension - no need to touch the scalability flag.
132 reductionShape[reductionDim] = 1;
133 VectorType reductionType =
134 VectorType::get(reductionShape, elType, reductionScalableDims);
135
136 SmallVector<int64_t> offsets(destRank, 0);
137 SmallVector<int64_t> strides(destRank, 1);
138 SmallVector<int64_t> sizes(destShape);
139 sizes[reductionDim] = 1;
140 ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
141 ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
142
143 Value lastOutput, lastInput;
144 for (int i = 0; i < destShape[reductionDim]; i++) {
145 offsets[reductionDim] = i;
146 ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
147 Value input = vector::ExtractStridedSliceOp::create(
148 rewriter, loc, reductionType, scanOp.getSource(), scanOffsets,
149 scanSizes, scanStrides);
150 Value output;
151 if (i == 0) {
152 if (inclusive) {
153 output = input;
154 } else {
155 if (initialValueRank == 0) {
156 // ShapeCastOp cannot handle 0-D vectors
157 output = vector::BroadcastOp::create(rewriter, loc, input.getType(),
158 scanOp.getInitialValue());
159 } else {
160 output = vector::ShapeCastOp::create(rewriter, loc, input.getType(),
161 scanOp.getInitialValue());
162 }
163 }
164 } else {
165 Value y = inclusive ? input : lastInput;
166 output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
167 lastOutput, y);
168 }
169 result = vector::InsertStridedSliceOp::create(rewriter, loc, output,
170 result, offsets, strides);
171 lastOutput = output;
172 lastInput = input;
173 }
174
175 Value reduction;
176 if (initialValueRank == 0) {
177 Value v = vector::ExtractOp::create(rewriter, loc, lastOutput, 0);
178 reduction =
179 vector::BroadcastOp::create(rewriter, loc, initialValueType, v);
180 } else {
181 reduction = vector::ShapeCastOp::create(rewriter, loc, initialValueType,
182 lastOutput);
183 }
184
185 rewriter.replaceOp(scanOp, {result, reduction});
186 return success();
187 }
188};
189} // namespace
190
193 patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
194}
return success()
ArrayAttr()
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks to see if the vector combining kind is consistent with the integer or float elem...
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:281
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Type getType() const
Return the type of this value.
Definition Value.h:105
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...