MLIR  19.0.0git
LowerVectorScan.cpp
Go to the documentation of this file.
1 //===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.scan' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
35 
36 #define DEBUG_TYPE "vector-broadcast-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 /// This function checks to see if the vector combining kind
42 /// is consistent with the integer or float element type.
43 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
44  using vector::CombiningKind;
45  enum class KindType { FLOAT, INT, INVALID };
46  KindType type{KindType::INVALID};
47  switch (kind) {
48  case CombiningKind::MINNUMF:
49  case CombiningKind::MINIMUMF:
50  case CombiningKind::MAXNUMF:
51  case CombiningKind::MAXIMUMF:
52  type = KindType::FLOAT;
53  break;
55  case CombiningKind::MINSI:
56  case CombiningKind::MAXUI:
57  case CombiningKind::MAXSI:
58  case CombiningKind::AND:
59  case CombiningKind::OR:
60  case CombiningKind::XOR:
61  type = KindType::INT;
62  break;
63  case CombiningKind::ADD:
64  case CombiningKind::MUL:
65  type = isInt ? KindType::INT : KindType::FLOAT;
66  break;
67  }
68  bool isValidIntKind = (type == KindType::INT) && isInt;
69  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
70  return (isValidIntKind || isValidFloatKind);
71 }
72 
73 namespace {
74 /// Convert vector.scan op into arith ops and vector.insert_strided_slice /
75 /// vector.extract_strided_slice.
76 ///
77 /// Example:
78 ///
79 /// ```
80 /// %0:2 = vector.scan <add>, %arg0, %arg1
81 /// {inclusive = true, reduction_dim = 1} :
82 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
83 /// ```
84 ///
85 /// is converted to:
86 ///
87 /// ```
88 /// %cst = arith.constant dense<0> : vector<2x3xi32>
89 /// %0 = vector.extract_strided_slice %arg0
90 /// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
91 /// : vector<2x3xi32> to vector<2x1xi32>
92 /// %1 = vector.insert_strided_slice %0, %cst
93 /// {offsets = [0, 0], strides = [1, 1]}
94 /// : vector<2x1xi32> into vector<2x3xi32>
95 /// %2 = vector.extract_strided_slice %arg0
96 /// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
97 /// : vector<2x3xi32> to vector<2x1xi32>
98 /// %3 = arith.muli %0, %2 : vector<2x1xi32>
99 /// %4 = vector.insert_strided_slice %3, %1
100 /// {offsets = [0, 1], strides = [1, 1]}
101 /// : vector<2x1xi32> into vector<2x3xi32>
102 /// %5 = vector.extract_strided_slice %arg0
103 /// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
104 /// : vector<2x3xi32> to vector<2x1xi32>
105 /// %6 = arith.muli %3, %5 : vector<2x1xi32>
106 /// %7 = vector.insert_strided_slice %6, %4
107 /// {offsets = [0, 2], strides = [1, 1]}
108 /// : vector<2x1xi32> into vector<2x3xi32>
109 /// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
110 /// return %7, %8 : vector<2x3xi32>, vector<2xi32>
111 /// ```
112 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
114 
115  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
116  PatternRewriter &rewriter) const override {
117  auto loc = scanOp.getLoc();
118  VectorType destType = scanOp.getDestType();
119  ArrayRef<int64_t> destShape = destType.getShape();
120  auto elType = destType.getElementType();
121  bool isInt = elType.isIntOrIndex();
122  if (!isValidKind(isInt, scanOp.getKind()))
123  return failure();
124 
125  VectorType resType = VectorType::get(destShape, elType);
126  Value result = rewriter.create<arith::ConstantOp>(
127  loc, resType, rewriter.getZeroAttr(resType));
128  int64_t reductionDim = scanOp.getReductionDim();
129  bool inclusive = scanOp.getInclusive();
130  int64_t destRank = destType.getRank();
131  VectorType initialValueType = scanOp.getInitialValueType();
132  int64_t initialValueRank = initialValueType.getRank();
133 
134  SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
135  reductionShape[reductionDim] = 1;
136  VectorType reductionType = VectorType::get(reductionShape, elType);
137  SmallVector<int64_t> offsets(destRank, 0);
138  SmallVector<int64_t> strides(destRank, 1);
139  SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
140  sizes[reductionDim] = 1;
141  ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
142  ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
143 
144  Value lastOutput, lastInput;
145  for (int i = 0; i < destShape[reductionDim]; i++) {
146  offsets[reductionDim] = i;
147  ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
148  Value input = rewriter.create<vector::ExtractStridedSliceOp>(
149  loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
150  scanStrides);
151  Value output;
152  if (i == 0) {
153  if (inclusive) {
154  output = input;
155  } else {
156  if (initialValueRank == 0) {
157  // ShapeCastOp cannot handle 0-D vectors
158  output = rewriter.create<vector::BroadcastOp>(
159  loc, input.getType(), scanOp.getInitialValue());
160  } else {
161  output = rewriter.create<vector::ShapeCastOp>(
162  loc, input.getType(), scanOp.getInitialValue());
163  }
164  }
165  } else {
166  Value y = inclusive ? input : lastInput;
167  output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
168  lastOutput, y);
169  }
170  result = rewriter.create<vector::InsertStridedSliceOp>(
171  loc, output, result, offsets, strides);
172  lastOutput = output;
173  lastInput = input;
174  }
175 
176  Value reduction;
177  if (initialValueRank == 0) {
178  Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
179  reduction =
180  rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
181  } else {
182  reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
183  lastOutput);
184  }
185 
186  rewriter.replaceOp(scanOp, {result, reduction});
187  return success();
188  }
189 };
190 } // namespace
191 
193  RewritePatternSet &patterns, PatternBenefit benefit) {
194  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
195 }
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks to see if the vector combining kind is consistent with the integer or float elem...
#define MINUI(lhs, rhs)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362