MLIR  18.0.0git
LowerVectorGather.cpp
Go to the documentation of this file.
1 //===- LowerVectorGather.cpp - Lower 'vector.gather' operation ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements target-independent rewrites and utilities to lower the
10 // 'vector.gather' operation.
11 //
12 //===----------------------------------------------------------------------===//
13 
27 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Location.h"
30 #include "mlir/IR/Matchers.h"
31 #include "mlir/IR/PatternMatch.h"
32 #include "mlir/IR/TypeUtilities.h"
35 
// LLVM debug category for this file (selected with `-debug-only=`). Named
// after the gather lowering instead of reusing the broadcast lowering's
// category, so enabling one does not also enable the other.
#define DEBUG_TYPE "vector-gather-lowering"
37 
38 using namespace mlir;
39 using namespace mlir::vector;
40 
41 namespace {
42 /// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
43 /// outermost dimension. For example:
44 /// ```
45 /// %g = vector.gather %base[%c0][%v], %mask, %pass_thru :
46 /// ... into vector<2x3xf32>
47 ///
48 /// ==>
49 ///
50 /// %0 = arith.constant dense<0.0> : vector<2x3xf32>
51 /// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ...
52 /// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32>
53 /// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ...
54 /// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32>
55 /// ```
56 ///
57 /// When applied exhaustively, this will produce a sequence of 1-d gather ops.
58 struct FlattenGather : OpRewritePattern<vector::GatherOp> {
60 
61  LogicalResult matchAndRewrite(vector::GatherOp op,
62  PatternRewriter &rewriter) const override {
63  VectorType resultTy = op.getType();
64  if (resultTy.getRank() < 2)
65  return rewriter.notifyMatchFailure(op, "already flat");
66 
67  Location loc = op.getLoc();
68  Value indexVec = op.getIndexVec();
69  Value maskVec = op.getMask();
70  Value passThruVec = op.getPassThru();
71 
72  Value result = rewriter.create<arith::ConstantOp>(
73  loc, resultTy, rewriter.getZeroAttr(resultTy));
74 
75  Type subTy = VectorType::get(resultTy.getShape().drop_front(),
76  resultTy.getElementType());
77 
78  for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
79  int64_t thisIdx[1] = {i};
80 
81  Value indexSubVec =
82  rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
83  Value maskSubVec =
84  rewriter.create<vector::ExtractOp>(loc, maskVec, thisIdx);
85  Value passThruSubVec =
86  rewriter.create<vector::ExtractOp>(loc, passThruVec, thisIdx);
87  Value subGather = rewriter.create<vector::GatherOp>(
88  loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec,
89  passThruSubVec);
90  result =
91  rewriter.create<vector::InsertOp>(loc, subGather, result, thisIdx);
92  }
93 
94  rewriter.replaceOp(op, result);
95  return success();
96  }
97 };
98 
99 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
100 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
101 /// loads/extracts are made conditional using `scf.if` ops.
102 struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
104 
105  LogicalResult matchAndRewrite(vector::GatherOp op,
106  PatternRewriter &rewriter) const override {
107  VectorType resultTy = op.getType();
108  if (resultTy.getRank() != 1)
109  return rewriter.notifyMatchFailure(op, "unsupported rank");
110 
111  Location loc = op.getLoc();
112  Type elemTy = resultTy.getElementType();
113  // Vector type with a single element. Used to generate `vector.loads`.
114  VectorType elemVecTy = VectorType::get({1}, elemTy);
115 
116  Value condMask = op.getMask();
117  Value base = op.getBase();
118  Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
119  loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
120  op.getIndexVec());
121  auto baseOffsets = llvm::to_vector(op.getIndices());
122  Value lastBaseOffset = baseOffsets.back();
123 
124  Value result = op.getPassThru();
125 
126  // Emit a conditional access for each vector element.
127  for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) {
128  int64_t thisIdx[1] = {i};
129  Value condition =
130  rewriter.create<vector::ExtractOp>(loc, condMask, thisIdx);
131  Value index = rewriter.create<vector::ExtractOp>(loc, indexVec, thisIdx);
132  baseOffsets.back() =
133  rewriter.createOrFold<arith::AddIOp>(loc, lastBaseOffset, index);
134 
135  auto loadBuilder = [&](OpBuilder &b, Location loc) {
136  Value extracted;
137  if (isa<MemRefType>(base.getType())) {
138  // `vector.load` does not support scalar result; emit a vector load
139  // and extract the single result instead.
140  Value load =
141  b.create<vector::LoadOp>(loc, elemVecTy, base, baseOffsets);
142  int64_t zeroIdx[1] = {0};
143  extracted = b.create<vector::ExtractOp>(loc, load, zeroIdx);
144  } else {
145  extracted = b.create<tensor::ExtractOp>(loc, base, baseOffsets);
146  }
147 
148  Value newResult =
149  b.create<vector::InsertOp>(loc, extracted, result, thisIdx);
150  b.create<scf::YieldOp>(loc, newResult);
151  };
152  auto passThruBuilder = [result](OpBuilder &b, Location loc) {
153  b.create<scf::YieldOp>(loc, result);
154  };
155 
156  result =
157  rewriter
158  .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder,
159  /*elseBuilder=*/passThruBuilder)
160  .getResult(0);
161  }
162 
163  rewriter.replaceOp(op, result);
164  return success();
165  }
166 };
167 } // namespace
168 
170  RewritePatternSet &patterns, PatternBenefit benefit) {
171  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
172  benefit);
173 }
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
IndexType getIndexType()
Definition: Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:505
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:686
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the rewriter that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:660
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Populate the pattern set with the following patterns:
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:361