MLIR  17.0.0git
VectorOps.h
Go to the documentation of this file.
1 //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Vector dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
14 #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
15 
18 #include "mlir/IR/AffineMap.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/IR/OpDefinition.h"
23 #include "mlir/IR/PatternMatch.h"
30 #include "llvm/ADT/StringExtras.h"
31 
32 // Pull in all enum type definitions and utility function declarations.
33 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
34 
35 #define GET_ATTRDEF_CLASSES
36 #include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
37 
38 namespace mlir {
39 class MLIRContext;
40 class RewritePatternSet;
41 
42 namespace arith {
43 enum class AtomicRMWKind : uint64_t;
44 } // namespace arith
45 
46 namespace vector {
47 class TransferReadOp;
48 class TransferWriteOp;
49 class VectorDialect;
50 
51 namespace detail {
52 struct BitmaskEnumStorage;
53 } // namespace detail
54 
55 /// Default callback to build a region with a 'vector.yield' terminator with no
56 /// arguments.
57 void buildTerminatedBody(OpBuilder &builder, Location loc);
58 
59 /// Return whether `srcType` can be broadcast to `dstVectorType` under the
60 /// semantics of the `vector.broadcast` op.
62  Success = 0,
63  SourceRankHigher = 1,
66 };
68 isBroadcastableTo(Type srcType, VectorType dstVectorType,
69  std::pair<int, int> *mismatchingDims = nullptr);
70 
71 /// Collect a set of vector-to-vector canonicalization patterns.
73  PatternBenefit benefit = 1);
74 
75 /// Collect a set of vector.shape_cast folding patterns.
77  PatternBenefit benefit = 1);
78 
79 /// Collect a set of leading one dimension removal patterns.
80 ///
81 /// These patterns insert vector.shape_cast to remove leading one dimensions
82 /// to expose more canonical forms of read/write/insert/extract operations.
83 /// With them, there are more chances that we can cancel out extract-insert
84 /// pairs or forward write-read pairs.
86  PatternBenefit benefit = 1);
87 
88 /// Collect a set of unit-dimension removal patterns.
89 ///
90 /// These patterns insert rank-reducing memref.subview ops to remove unit
91 /// dimensions. With them, there are more chances that we can avoid
92 /// potentially expensive vector.shape_cast operations.
94  PatternBenefit benefit = 1);
95 
96 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
97 /// memref.
98 ///
99 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
100 /// to transform multiple small n-D transfers into a larger 1-D transfer where
101 /// the memref contiguity properties allow it.
103  PatternBenefit benefit = 1);
104 
105 /// Collect a set of patterns that bubble up/down bitcast ops.
106 ///
107 /// These patterns move vector.bitcast ops to be before insert ops or after
108 /// extract ops where suitable. With them, bitcast will happen on smaller
109 /// vectors and there are more chances to share extract/insert ops.
111  PatternBenefit benefit = 1);
112 
113 /// These patterns materialize masks for various vector ops such as transfers.
115  bool force32BitVectorIndices,
116  PatternBenefit benefit = 1);
117 
118 /// Returns the integer type required for subscripts in the vector dialect.
119 IntegerType getVectorSubscriptType(Builder &builder);
120 
121 /// Returns an integer array attribute containing the given values using
122 /// the integer type required for subscripts in the vector dialect.
123 ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
124 
125 /// Returns the value obtained by reducing the vector into a scalar using the
126 /// operation kind associated with a binary AtomicRMWKind op.
127 Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
128  Location loc, Value vector);
129 
130 /// Return true if the last dimension of the MemRefType has unit stride. Also
131 /// return true for memrefs with no strides.
132 bool isLastMemrefDimUnitStride(MemRefType type);
133 
134 /// Build the default minor identity map suitable for a vector transfer. This
135 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
136 /// rank of the identity map must take the vector element type into account.
137 AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
138  VectorType vectorType);
139 
140 /// Return true if the transfer_write fully writes the data accessed by the
141 /// transfer_read.
142 bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
143 
144 /// Return true if the write op fully overwrites the priorWrite
145 /// transfer_write op.
146 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
147 
148 /// Same behavior as `isDisjointTransferSet` but doesn't require the operations
149 /// to have the same tensor/memref. This allows comparing operations accessing
150 /// different tensors.
151 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
152  VectorTransferOpInterface transferB);
153 
154 /// Return true if we can prove that the transfer operations access disjoint
155 /// memory.
156 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
157  VectorTransferOpInterface transferB);
158 
159 /// Return the result value of reducing two scalar/vector values with the
160 /// corresponding arith operation.
161 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
162  Value v1, Value acc, Value mask = Value());
163 
164 /// Returns true if `attr` has "parallel" iterator type semantics.
165 inline bool isParallelIterator(Attribute attr) {
166  return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::parallel;
167 }
168 
169 /// Returns true if `attr` has "reduction" iterator type semantics.
170 inline bool isReductionIterator(Attribute attr) {
171  return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
172 }
173 
174 //===----------------------------------------------------------------------===//
175 // Vector Masking Utilities
176 //===----------------------------------------------------------------------===//
177 
178 /// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
179 /// as masked operation.
180 void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
181 
182 /// Creates a vector.mask operation around a maskable operation. Returns the
183 /// vector.mask operation if the mask provided is valid. Otherwise, returns the
184 /// maskable operation itself.
185 Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask,
186  Value passthru = Value());
187 
188 /// Creates a vector select operation that picks values from `newValue` or
189 /// `passthru` for each result vector lane based on `mask`. This utility is used
190 /// to propagate the pass-thru value for masked-out or speculatively executed
191 /// lanes. VP intrinsics do not support pass-thru values and every masked-out
192 /// lane is set to poison. LLVM backends are usually able to match op + select
193 /// patterns and fold them into native target instructions.
194 Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
195  Value passthru);
196 
197 } // namespace vector
198 } // namespace mlir
199 
200 #define GET_OP_CLASSES
201 #include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
202 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc"
203 
204 #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:43
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U cast() const
Definition: Attributes.h:176
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:202
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:290
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:170
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback to build a region with a 'vector.yield' terminator with no arguments.
Definition: VectorOps.cpp:107
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Same behavior as isDisjointTransferSet but doesn't require the operations to have the same tensor/mem...
Definition: VectorOps.cpp:177
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:142
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector.shape_cast folding patterns.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
Operation * maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, Value passthru=Value())
Creates a vector.mask operation around a maskable operation.
bool isLastMemrefDimUnitStride(MemRefType type)
Return true if the last dimension of the MemRefType has unit stride.
Definition: VectorOps.cpp:135
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of one dimension removal patterns.
Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, Value passthru)
Creates a vector select operation that picks values from newValue or passthru for each result vector ...
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns that bubble up/down bitcast ops.
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:1949
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:165
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of vector-to-vector canonicalization patterns.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit=1)
These patterns materialize masks for various vector ops such as transfers.
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Return true if we can prove that the transfer operations access disjoint memory.
Definition: VectorOps.cpp:209
void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp)
Create the vector.yield-ended region of a vector.mask op with maskableOp as masked operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, Value mask=Value())
Return the result value of reducing two scalar/vector values with the corresponding arith operation.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:489
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector....
Definition: VectorOps.h:61
void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of patterns to flatten n-D vector transfers on contiguous memref.
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns, PatternBenefit benefit=1)
Collect a set of leading one dimension removal patterns.
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:286
This header declares functions that assist transformations in the MemRef dialect.