MLIR  19.0.0git
VectorDistribution.h
Go to the documentation of this file.
1 //===- VectorDistribution.h - Vector distribution patterns --*- C++------*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
10 #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
11 
13 
14 namespace mlir {
15 class RewritePatternSet;
16 namespace vector {
17 
19  /// Lambda function to let users allocate memory needed for the lowering of
20  /// WarpExecuteOnLane0Op.
21  /// The function needs to return an allocation that the lowering can use as
22  /// temporary memory. The allocation needs to match the shape of the type (the
23  /// type may be VectorType or a scalar) and be available for the current warp.
24  /// If there are several warps running in parallel the allocation needs to be
25  /// split so that each warp has its own allocation.
27  std::function<Value(Location, OpBuilder &, WarpExecuteOnLane0Op, Type)>;
29 
30  /// Lambda function to let user emit operation to synchronize all the threads
31  /// within a warp. After this operation all the threads can see any memory
32  /// written before the operation.
34  std::function<void(Location, OpBuilder &, WarpExecuteOnLane0Op)>;
36 };
37 
39  RewritePatternSet &patterns,
41  PatternBenefit benefit = 1);
42 
/// Callback returning the affine map that describes how a given value should
/// be distributed.
43 using DistributionMapFn = std::function<AffineMap(Value)>;
44 
45 /// Distribute transfer_write ops based on the affine map returned by
46 /// `distributionMapFn`. Writes of size more than `maxNumElementsToExtract`
47 /// will not be distributed (it should be less than the warp size).
48 ///
49 /// Example:
50 /// ```
51 /// vector.warp_execute_on_lane_0(%id) {
52 /// ...
53 /// vector.transfer_write %v, %A[%c0] : vector<32xf32>, memref<128xf32>
54 /// vector.yield
55 /// }
56 /// ```
57 /// To
58 /// ```
59 /// %r = vector.warp_execute_on_lane_0(%id) -> (vector<1xf32>) {
60 /// ...
61 /// vector.yield %v : vector<32xf32>
62 /// }
63 /// vector.transfer_write %r, %A[%id] : vector<1xf32>, memref<128xf32>
64 /// ```
65 /// When applied at the same time as the vector propagation patterns,
66 /// distribution of `vector.transfer_write` is expected to have the highest
67 /// priority (pattern benefit). By making propagation of `vector.transfer_read`
68 /// be the lowest priority pattern, it will be the last vector operation to
69 /// distribute, meaning writes should propagate first.
70 void populateDistributeTransferWriteOpPatterns(
71  RewritePatternSet &patterns, const DistributionMapFn &distributionMapFn,
72  unsigned maxNumElementsToExtract, PatternBenefit benefit = 2);
73 
74 /// Move scalar operations with no dependency on the warp op `op` outside of
75 /// its region.
76 void moveScalarUniformCode(WarpExecuteOnLane0Op op);
77 
78 /// Lambda signature to compute a warp shuffle of a given value from a given
79 /// lane within a given warp size (`int64_t`).
80 using WarpShuffleFromIdxFn =
81  std::function<Value(Location, OpBuilder &b, Value, Value, int64_t)>;
82 
83 /// Collect patterns to propagate warp distribution. `distributionMapFn` is used
84 /// to decide how a value should be distributed when this cannot be inferred
85 /// from its uses.
86 ///
87 /// The separate `readBenefit` for the `vector.transfer_read` op pattern
88 /// is given to ensure the order of reads/writes before and after distribution
89 /// is consistent. As noted above, writes are expected to have the highest
90 /// priority for distribution, but are only ever distributed if adjacent to the
91 /// yield. By making reads the lowest priority pattern, it will be the last
92 /// vector operation to distribute, meaning writes should propagate first. This
93 /// is relatively brittle when ops fail to distribute, but that is a limitation
94 /// of these propagation patterns when there is a dependency not modeled by SSA.
95 void populatePropagateWarpVectorDistributionPatterns(
96  RewritePatternSet &pattern, const DistributionMapFn &distributionMapFn,
97  const WarpShuffleFromIdxFn &warpShuffleFromIdxFn,
98  PatternBenefit benefit = 1, PatternBenefit readBenefit = 0);
99 
100 /// Lambda signature to compute a reduction of a distributed value for the given
101 /// reduction kind (`CombiningKind`) and size (`uint32_t`).
102 using DistributedReductionFn =
103  std::function<Value(Location, OpBuilder &, Value, CombiningKind, uint32_t)>;
104 
105 /// Collect patterns to distribute vector reduction ops using the given lambda
106 /// (`distributedReductionFn`) to distribute the reduction op.
107 void populateDistributeReduction(
108  RewritePatternSet &pattern,
109  const DistributedReductionFn &distributedReductionFn,
110  PatternBenefit benefit = 1);
111 
112 } // namespace vector
113 } // namespace mlir
114 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORDISTRIBUTION_H_
static llvm::ManagedStatic< PassManagerOptions > options
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
void populateWarpExecuteOnLane0OpToScfForPattern(RewritePatternSet &patterns, const WarpExecuteOnLane0LoweringOptions &options, PatternBenefit benefit=1)
std::function< AffineMap(Value)> DistributionMapFn
Include the generated interface declarations.
std::function< Value(Location, OpBuilder &, WarpExecuteOnLane0Op, Type)> WarpAllocationFn
Lambda function to let users allocate memory needed for the lowering of WarpExecuteOnLane0Op.
std::function< void(Location, OpBuilder &, WarpExecuteOnLane0Op)> WarpSyncronizationFn
Lambda function to let user emit operation to synchronize all the threads within a warp.