MLIR  18.0.0git
LowerVectorScan.cpp
Go to the documentation of this file.
1 //===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.scan' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
35 
// Debug type for this file. It names the scan lowering rather than the
// broadcast lowering the previous value was copied from.
#define DEBUG_TYPE "vector-scan-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 /// This function constructs the appropriate integer or float
42 /// operation given the vector combining kind and operands. The
43 /// supported int operations are : add, mul, min (signed/unsigned),
44 /// max(signed/unsigned), and, or, xor. The supported float
45 /// operations are : add, mul, min and max.
47  vector::CombiningKind kind,
48  PatternRewriter &rewriter) {
49  using vector::CombiningKind;
50 
51  auto elType = cast<VectorType>(x.getType()).getElementType();
52  bool isInt = elType.isIntOrIndex();
53 
54  Value combinedResult{nullptr};
55  switch (kind) {
56  case CombiningKind::ADD:
57  if (isInt)
58  combinedResult = rewriter.create<arith::AddIOp>(loc, x, y);
59  else
60  combinedResult = rewriter.create<arith::AddFOp>(loc, x, y);
61  break;
62  case CombiningKind::MUL:
63  if (isInt)
64  combinedResult = rewriter.create<arith::MulIOp>(loc, x, y);
65  else
66  combinedResult = rewriter.create<arith::MulFOp>(loc, x, y);
67  break;
68  case CombiningKind::MINUI:
69  combinedResult = rewriter.create<arith::MinUIOp>(loc, x, y);
70  break;
71  case CombiningKind::MINSI:
72  combinedResult = rewriter.create<arith::MinSIOp>(loc, x, y);
73  break;
74  case CombiningKind::MAXUI:
75  combinedResult = rewriter.create<arith::MaxUIOp>(loc, x, y);
76  break;
77  case CombiningKind::MAXSI:
78  combinedResult = rewriter.create<arith::MaxSIOp>(loc, x, y);
79  break;
80  case CombiningKind::AND:
81  combinedResult = rewriter.create<arith::AndIOp>(loc, x, y);
82  break;
83  case CombiningKind::OR:
84  combinedResult = rewriter.create<arith::OrIOp>(loc, x, y);
85  break;
86  case CombiningKind::XOR:
87  combinedResult = rewriter.create<arith::XOrIOp>(loc, x, y);
88  break;
89  case CombiningKind::MINF:
90  case CombiningKind::MINIMUMF:
91  combinedResult = rewriter.create<arith::MinimumFOp>(loc, x, y);
92  break;
93  case CombiningKind::MAXF:
94  case CombiningKind::MAXIMUMF:
95  combinedResult = rewriter.create<arith::MaximumFOp>(loc, x, y);
96  break;
97  }
98  return combinedResult;
99 }
100 
101 /// This function checks to see if the vector combining kind
102 /// is consistent with the integer or float element type.
103 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
104  using vector::CombiningKind;
105  enum class KindType { FLOAT, INT, INVALID };
106  KindType type{KindType::INVALID};
107  switch (kind) {
108  case CombiningKind::MINF:
109  case CombiningKind::MINIMUMF:
110  case CombiningKind::MAXF:
111  case CombiningKind::MAXIMUMF:
112  type = KindType::FLOAT;
113  break;
114  case CombiningKind::MINUI:
115  case CombiningKind::MINSI:
116  case CombiningKind::MAXUI:
117  case CombiningKind::MAXSI:
118  case CombiningKind::AND:
119  case CombiningKind::OR:
120  case CombiningKind::XOR:
121  type = KindType::INT;
122  break;
123  case CombiningKind::ADD:
124  case CombiningKind::MUL:
125  type = isInt ? KindType::INT : KindType::FLOAT;
126  break;
127  }
128  bool isValidIntKind = (type == KindType::INT) && isInt;
129  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
130  return (isValidIntKind || isValidFloatKind);
131 }
132 
133 namespace {
134 /// Convert vector.scan op into arith ops and vector.insert_strided_slice /
135 /// vector.extract_strided_slice.
136 ///
137 /// Example:
138 ///
139 /// ```
140 /// %0:2 = vector.scan <add>, %arg0, %arg1
141 /// {inclusive = true, reduction_dim = 1} :
142 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
143 /// ```
144 ///
145 /// is converted to:
146 ///
147 /// ```
148 /// %cst = arith.constant dense<0> : vector<2x3xi32>
149 /// %0 = vector.extract_strided_slice %arg0
150 /// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
151 /// : vector<2x3xi32> to vector<2x1xi32>
152 /// %1 = vector.insert_strided_slice %0, %cst
153 /// {offsets = [0, 0], strides = [1, 1]}
154 /// : vector<2x1xi32> into vector<2x3xi32>
155 /// %2 = vector.extract_strided_slice %arg0
156 /// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
157 /// : vector<2x3xi32> to vector<2x1xi32>
158 /// %3 = arith.muli %0, %2 : vector<2x1xi32>
159 /// %4 = vector.insert_strided_slice %3, %1
160 /// {offsets = [0, 1], strides = [1, 1]}
161 /// : vector<2x1xi32> into vector<2x3xi32>
162 /// %5 = vector.extract_strided_slice %arg0
163 /// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
164 /// : vector<2x3xi32> to vector<2x1xi32>
165 /// %6 = arith.muli %3, %5 : vector<2x1xi32>
166 /// %7 = vector.insert_strided_slice %6, %4
167 /// {offsets = [0, 2], strides = [1, 1]}
168 /// : vector<2x1xi32> into vector<2x3xi32>
169 /// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
170 /// return %7, %8 : vector<2x3xi32>, vector<2xi32>
171 /// ```
172 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
174 
175  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
176  PatternRewriter &rewriter) const override {
177  auto loc = scanOp.getLoc();
178  VectorType destType = scanOp.getDestType();
179  ArrayRef<int64_t> destShape = destType.getShape();
180  auto elType = destType.getElementType();
181  bool isInt = elType.isIntOrIndex();
182  if (!isValidKind(isInt, scanOp.getKind()))
183  return failure();
184 
185  VectorType resType = VectorType::get(destShape, elType);
186  Value result = rewriter.create<arith::ConstantOp>(
187  loc, resType, rewriter.getZeroAttr(resType));
188  int64_t reductionDim = scanOp.getReductionDim();
189  bool inclusive = scanOp.getInclusive();
190  int64_t destRank = destType.getRank();
191  VectorType initialValueType = scanOp.getInitialValueType();
192  int64_t initialValueRank = initialValueType.getRank();
193 
194  SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
195  reductionShape[reductionDim] = 1;
196  VectorType reductionType = VectorType::get(reductionShape, elType);
197  SmallVector<int64_t> offsets(destRank, 0);
198  SmallVector<int64_t> strides(destRank, 1);
199  SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
200  sizes[reductionDim] = 1;
201  ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
202  ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
203 
204  Value lastOutput, lastInput;
205  for (int i = 0; i < destShape[reductionDim]; i++) {
206  offsets[reductionDim] = i;
207  ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
208  Value input = rewriter.create<vector::ExtractStridedSliceOp>(
209  loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
210  scanStrides);
211  Value output;
212  if (i == 0) {
213  if (inclusive) {
214  output = input;
215  } else {
216  if (initialValueRank == 0) {
217  // ShapeCastOp cannot handle 0-D vectors
218  output = rewriter.create<vector::BroadcastOp>(
219  loc, input.getType(), scanOp.getInitialValue());
220  } else {
221  output = rewriter.create<vector::ShapeCastOp>(
222  loc, input.getType(), scanOp.getInitialValue());
223  }
224  }
225  } else {
226  Value y = inclusive ? input : lastInput;
227  output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter);
228  assert(output != nullptr);
229  }
230  result = rewriter.create<vector::InsertStridedSliceOp>(
231  loc, output, result, offsets, strides);
232  lastOutput = output;
233  lastInput = input;
234  }
235 
236  Value reduction;
237  if (initialValueRank == 0) {
238  Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
239  reduction =
240  rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
241  } else {
242  reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
243  lastOutput);
244  }
245 
246  rewriter.replaceOp(scanOp, {result, reduction});
247  return success();
248  }
249 };
250 } // namespace
251 
253  RewritePatternSet &patterns, PatternBenefit benefit) {
254  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
255 }
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks to see if the vector combining kind is consistent with the integer or float elem...
static Value genOperator(Location loc, Value x, Value y, vector::CombiningKind kind, PatternRewriter &rewriter)
This function constructs the appropriate integer or float operation given the vector combining kind a...
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361