MLIR  22.0.0git
LowerVectorScan.cpp
Go to the documentation of this file.
1 //===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.scan' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Location.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/IR/TypeUtilities.h"
24 
// Debug-output / statistics category for this file; must name the scan
// lowering (not the broadcast lowering it was copied from) so that
// `-debug-only=vector-scan-lowering` selects the right messages.
#define DEBUG_TYPE "vector-scan-lowering"
26 
27 using namespace mlir;
28 using namespace mlir::vector;
29 
30 /// This function checks to see if the vector combining kind
31 /// is consistent with the integer or float element type.
32 static bool isValidKind(bool isInt, vector::CombiningKind kind) {
33  using vector::CombiningKind;
34  enum class KindType { FLOAT, INT, INVALID };
35  KindType type{KindType::INVALID};
36  switch (kind) {
37  case CombiningKind::MINNUMF:
38  case CombiningKind::MINIMUMF:
39  case CombiningKind::MAXNUMF:
40  case CombiningKind::MAXIMUMF:
41  type = KindType::FLOAT;
42  break;
44  case CombiningKind::MINSI:
45  case CombiningKind::MAXUI:
46  case CombiningKind::MAXSI:
47  case CombiningKind::AND:
48  case CombiningKind::OR:
49  case CombiningKind::XOR:
50  type = KindType::INT;
51  break;
52  case CombiningKind::ADD:
53  case CombiningKind::MUL:
54  type = isInt ? KindType::INT : KindType::FLOAT;
55  break;
56  }
57  bool isValidIntKind = (type == KindType::INT) && isInt;
58  bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt);
59  return (isValidIntKind || isValidFloatKind);
60 }
61 
62 namespace {
63 /// Convert vector.scan op into arith ops and vector.insert_strided_slice /
64 /// vector.extract_strided_slice.
65 ///
66 /// Example:
67 ///
68 /// ```
69 /// %0:2 = vector.scan <add>, %arg0, %arg1
70 /// {inclusive = true, reduction_dim = 1} :
71 /// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>)
72 /// ```
73 ///
74 /// is converted to:
75 ///
76 /// ```
77 /// %cst = arith.constant dense<0> : vector<2x3xi32>
78 /// %0 = vector.extract_strided_slice %arg0
79 /// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]}
80 /// : vector<2x3xi32> to vector<2x1xi32>
81 /// %1 = vector.insert_strided_slice %0, %cst
82 /// {offsets = [0, 0], strides = [1, 1]}
83 /// : vector<2x1xi32> into vector<2x3xi32>
84 /// %2 = vector.extract_strided_slice %arg0
85 /// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]}
86 /// : vector<2x3xi32> to vector<2x1xi32>
87 /// %3 = arith.muli %0, %2 : vector<2x1xi32>
88 /// %4 = vector.insert_strided_slice %3, %1
89 /// {offsets = [0, 1], strides = [1, 1]}
90 /// : vector<2x1xi32> into vector<2x3xi32>
91 /// %5 = vector.extract_strided_slice %arg0
92 /// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]}
93 /// : vector<2x3xi32> to vector<2x1xi32>
94 /// %6 = arith.muli %3, %5 : vector<2x1xi32>
95 /// %7 = vector.insert_strided_slice %6, %4
96 /// {offsets = [0, 2], strides = [1, 1]}
97 /// : vector<2x1xi32> into vector<2x3xi32>
98 /// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32>
99 /// return %7, %8 : vector<2x3xi32>, vector<2xi32>
100 /// ```
101 struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
103 
104  LogicalResult matchAndRewrite(vector::ScanOp scanOp,
105  PatternRewriter &rewriter) const override {
106  auto loc = scanOp.getLoc();
107  VectorType destType = scanOp.getDestType();
108  ArrayRef<int64_t> destShape = destType.getShape();
109  auto elType = destType.getElementType();
110  bool isInt = elType.isIntOrIndex();
111  if (!isValidKind(isInt, scanOp.getKind()))
112  return failure();
113 
114  VectorType resType = VectorType::get(destShape, elType);
115  Value result = arith::ConstantOp::create(rewriter, loc, resType,
116  rewriter.getZeroAttr(resType));
117  int64_t reductionDim = scanOp.getReductionDim();
118  bool inclusive = scanOp.getInclusive();
119  int64_t destRank = destType.getRank();
120  VectorType initialValueType = scanOp.getInitialValueType();
121  int64_t initialValueRank = initialValueType.getRank();
122 
123  SmallVector<int64_t> reductionShape(destShape);
124  reductionShape[reductionDim] = 1;
125  VectorType reductionType = VectorType::get(reductionShape, elType);
126  SmallVector<int64_t> offsets(destRank, 0);
127  SmallVector<int64_t> strides(destRank, 1);
128  SmallVector<int64_t> sizes(destShape);
129  sizes[reductionDim] = 1;
130  ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes);
131  ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides);
132 
133  Value lastOutput, lastInput;
134  for (int i = 0; i < destShape[reductionDim]; i++) {
135  offsets[reductionDim] = i;
136  ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets);
137  Value input = vector::ExtractStridedSliceOp::create(
138  rewriter, loc, reductionType, scanOp.getSource(), scanOffsets,
139  scanSizes, scanStrides);
140  Value output;
141  if (i == 0) {
142  if (inclusive) {
143  output = input;
144  } else {
145  if (initialValueRank == 0) {
146  // ShapeCastOp cannot handle 0-D vectors
147  output = vector::BroadcastOp::create(rewriter, loc, input.getType(),
148  scanOp.getInitialValue());
149  } else {
150  output = vector::ShapeCastOp::create(rewriter, loc, input.getType(),
151  scanOp.getInitialValue());
152  }
153  }
154  } else {
155  Value y = inclusive ? input : lastInput;
156  output = vector::makeArithReduction(rewriter, loc, scanOp.getKind(),
157  lastOutput, y);
158  }
159  result = vector::InsertStridedSliceOp::create(rewriter, loc, output,
160  result, offsets, strides);
161  lastOutput = output;
162  lastInput = input;
163  }
164 
165  Value reduction;
166  if (initialValueRank == 0) {
167  Value v = vector::ExtractOp::create(rewriter, loc, lastOutput, 0);
168  reduction =
169  vector::BroadcastOp::create(rewriter, loc, initialValueType, v);
170  } else {
171  reduction = vector::ShapeCastOp::create(rewriter, loc, initialValueType,
172  lastOutput);
173  }
174 
175  rewriter.replaceOp(scanOp, {result, reduction});
176  return success();
177  }
178 };
179 } // namespace
180 
183  patterns.add<ScanToArithOps>(patterns.getContext(), benefit);
184 }
union mlir::linalg::@1227::ArityGroupAndKind::Kind kind
static bool isValidKind(bool isInt, vector::CombiningKind kind)
This function checks to see if the vector combining kind is consistent with the integer or float elem...
#define MINUI(lhs, rhs)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:319
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:276
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319