MLIR  20.0.0git
LowerVectorScan.cpp
Go to the documentation of this file.
1 //===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.scan' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
34 
// Debug tag for this file's `vector.scan` lowering (used with -debug-only=).
#define DEBUG_TYPE "vector-scan-lowering"
36 
37 using namespace mlir;
38 using namespace mlir::vector;
39 
40 /// This function checks to see if the vector combining kind
41 /// is consistent with the integer or float element type.
42 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
43  using vector::CombiningKind;
44  enum class KindType { FLOAT, INT, INVALID };
45  KindType type{KindType::INVALID};
46  switch (kind) {
47  case CombiningKind::MINNUMF:
48  case CombiningKind::MINIMUMF:
49  case CombiningKind::MAXNUMF:
50  case CombiningKind::MAXIMUMF:
51  type = KindType::FLOAT;
52  break;
54  case CombiningKind::MINSI:
55  case CombiningKind::MAXUI:
56  case CombiningKind::MAXSI:
57  case CombiningKind::AND:
58  case CombiningKind::OR:
59  case CombiningKind::XOR:
60  type = KindType::INT;
61  break;
62  case CombiningKind::ADD:
63  case CombiningKind::MUL:
64  type = isInt ? KindType::INT : KindType::FLOAT;
65  break;
66  }
67  bool isValidIntKind = (type == KindType::INT) && isInt;
68  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
69  return (isValidIntKind || isValidFloatKind);
70 }
71 
72 namespace {
73 /// Convert vector.scan op into arith ops and vector.insert_strided_slice /
74 /// vector.extract_strided_slice.
75 ///
76 /// Example:
77 ///
78 /// ```
79 /// %0:2 = vector.scan <add>, %arg0, %arg1
80 /// {inclusive = true, reduction_dim = 1} :
81 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
82 /// ```
83 ///
84 /// is converted to:
85 ///
86 /// ```
87 /// %cst = arith.constant dense<0> : vector<2x3xi32>
88 /// %0 = vector.extract_strided_slice %arg0
89 /// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
90 /// : vector<2x3xi32> to vector<2x1xi32>
91 /// %1 = vector.insert_strided_slice %0, %cst
92 /// {offsets = [0, 0], strides = [1, 1]}
93 /// : vector<2x1xi32> into vector<2x3xi32>
94 /// %2 = vector.extract_strided_slice %arg0
95 /// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
96 /// : vector<2x3xi32> to vector<2x1xi32>
97 /// %3 = arith.muli %0, %2 : vector<2x1xi32>
98 /// %4 = vector.insert_strided_slice %3, %1
99 /// {offsets = [0, 1], strides = [1, 1]}
100 /// : vector<2x1xi32> into vector<2x3xi32>
101 /// %5 = vector.extract_strided_slice %arg0
102 /// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
103 /// : vector<2x3xi32> to vector<2x1xi32>
104 /// %6 = arith.muli %3, %5 : vector<2x1xi32>
105 /// %7 = vector.insert_strided_slice %6, %4
106 /// {offsets = [0, 2], strides = [1, 1]}
107 /// : vector<2x1xi32> into vector<2x3xi32>
108 /// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
109 /// return %7, %8 : vector<2x3xi32>, vector<2xi32>
110 /// ```
111 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
113 
114  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
115  PatternRewriter &rewriter) const override {
116  auto loc = scanOp.getLoc();
117  VectorType destType = scanOp.getDestType();
118  ArrayRef<int64_t> destShape = destType.getShape();
119  auto elType = destType.getElementType();
120  bool isInt = elType.isIntOrIndex();
121  if (!isValidKind(isInt, scanOp.getKind()))
122  return failure();
123 
124  VectorType resType = VectorType::get(destShape, elType);
125  Value result = rewriter.create<arith::ConstantOp>(
126  loc, resType, rewriter.getZeroAttr(resType));
127  int64_t reductionDim = scanOp.getReductionDim();
128  bool inclusive = scanOp.getInclusive();
129  int64_t destRank = destType.getRank();
130  VectorType initialValueType = scanOp.getInitialValueType();
131  int64_t initialValueRank = initialValueType.getRank();
132 
133  SmallVector<int64_t> reductionShape(destShape.begin(), destShape.end());
134  reductionShape[reductionDim] = 1;
135  VectorType reductionType = VectorType::get(reductionShape, elType);
136  SmallVector<int64_t> offsets(destRank, 0);
137  SmallVector<int64_t> strides(destRank, 1);
138  SmallVector<int64_t> sizes(destShape.begin(), destShape.end());
139  sizes[reductionDim] = 1;
140  ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
141  ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
142 
143  Value lastOutput, lastInput;
144  for (int i = 0; i < destShape[reductionDim]; i++) {
145  offsets[reductionDim] = i;
146  ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
147  Value input = rewriter.create<vector::ExtractStridedSliceOp>(
148  loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes,
149  scanStrides);
150  Value output;
151  if (i == 0) {
152  if (inclusive) {
153  output = input;
154  } else {
155  if (initialValueRank == 0) {
156  // ShapeCastOp cannot handle 0-D vectors
157  output = rewriter.create<vector::BroadcastOp>(
158  loc, input.getType(), scanOp.getInitialValue());
159  } else {
160  output = rewriter.create<vector::ShapeCastOp>(
161  loc, input.getType(), scanOp.getInitialValue());
162  }
163  }
164  } else {
165  Value y = inclusive ? input : lastInput;
166  output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
167  lastOutput, y);
168  }
169  result = rewriter.create<vector::InsertStridedSliceOp>(
170  loc, output, result, offsets, strides);
171  lastOutput = output;
172  lastInput = input;
173  }
174 
175  Value reduction;
176  if (initialValueRank == 0) {
177  Value v = rewriter.create<vector::ExtractOp>(loc, lastOutput, 0);
178  reduction =
179  rewriter.create<vector::BroadcastOp>(loc, initialValueType, v);
180  } else {
181  reduction = rewriter.create<vector::ShapeCastOp>(loc, initialValueType,
182  lastOutput);
183  }
184 
185  rewriter.replaceOp(scanOp, {result, reduction});
186  return success();
187  }
188 };
189 } // namespace
190 
192  RewritePatternSet &patterns, PatternBenefit benefit) {
193  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
194 }
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks to see if the vector combining kind is consistent with the integer or float elem...
#define MINUI(lhs, rhs)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:335
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:292
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362