MLIR  15.0.0git
VectorOps.h
Go to the documentation of this file.
1 //===- VectorOps.h - MLIR Vector Dialect Operations -------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the Vector dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
14 #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
15 
16 #include "mlir/IR/AffineMap.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/OpDefinition.h"
21 #include "mlir/IR/PatternMatch.h"
27 #include "llvm/ADT/StringExtras.h"
28 
29 // Pull in all enum type definitions and utility function declarations.
30 #include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
31 
32 namespace mlir {
33 class MLIRContext;
34 class RewritePatternSet;
35 
36 namespace arith {
37 enum class AtomicRMWKind : uint64_t;
38 } // namespace arith
39 
40 namespace vector {
41 class TransferReadOp;
42 class TransferWriteOp;
43 class VectorDialect;
44 
45 namespace detail {
46 struct BitmaskEnumStorage;
47 } // namespace detail
48 
49 /// Return whether `srcType` can be broadcast to `dstVectorType` under the
50 /// semantics of the `vector.broadcast` op.
52  Success = 0,
53  SourceRankHigher = 1,
56 };
58 isBroadcastableTo(Type srcType, VectorType dstVectorType,
59  std::pair<int, int> *mismatchingDims = nullptr);
60 
61 /// Collect a set of vector-to-vector canonicalization patterns.
63  RewritePatternSet &patterns);
64 
65 /// Collect a set of vector.shape_cast folding patterns.
67 
68 /// Collect a set of leading one dimension removal patterns.
69 ///
70 /// These patterns insert vector.shape_cast to remove leading one dimensions
71 /// to expose more canonical forms of read/write/insert/extract operations.
72 /// With them, there are more chances that we can cancel out extract-insert
73 /// pairs or forward write-read pairs.
75 
76 /// Collect a set of one dimension removal patterns.
77 ///
78 /// These patterns insert rank-reducing memref.subview ops to remove one
79 /// dimensions. With them, there are more chances that we can avoid
80 /// potentially expensive vector.shape_cast operations.
82 
83 /// Collect a set of patterns to flatten n-D vector transfers on contiguous
84 /// memref.
85 ///
86 /// These patterns insert memref.collapse_shape + vector.shape_cast patterns
87 /// to transform multiple small n-D transfers into a larger 1-D transfer where
88 /// the memref contiguity properties allow it.
90 
91 /// Collect a set of patterns that bubble up/down bitcast ops.
92 ///
93 /// These patterns move vector.bitcast ops to be before insert ops or after
94 /// extract ops where suitable. With them, bitcast will happen on smaller
95 /// vectors and there are more chances to share extract/insert ops.
97 
98 /// Collect a set of transfer read/write lowering patterns.
99 ///
100 /// These patterns lower transfer ops to simpler ops like `vector.load`,
101 /// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank
102 /// of at most `maxTransferRank` are lowered. This is useful when combined with
103 /// VectorToSCF, which reduces the rank of vector transfer ops.
105  RewritePatternSet &patterns,
106  llvm::Optional<unsigned> maxTransferRank = llvm::None);
107 
108 /// These patterns materialize masks for various vector ops such as transfers.
110  bool force32BitVectorIndices);
111 
112 /// Collect a set of patterns to propagate insert_map/extract_map in the ssa
113 /// chain.
115 
116 /// An attribute that specifies the combining function for `vector.contract`,
117 /// and `vector.reduction`.
119  : public Attribute::AttrBase<CombiningKindAttr, Attribute,
120  detail::BitmaskEnumStorage> {
121 public:
122  using Base::Base;
123 
124  static CombiningKindAttr get(CombiningKind kind, MLIRContext *context);
125 
126  CombiningKind getKind() const;
127 
128  void print(AsmPrinter &p) const;
129  static Attribute parse(AsmParser &parser, Type type);
130 };
131 
132 /// Collects patterns to progressively lower vector.broadcast ops on high-D
133 /// vectors to low-D vector ops.
135 
136 /// Collects patterns to progressively lower vector mask ops into elementary
137 /// selection and insertion ops.
139 
140 /// Collects patterns to progressively lower vector.shape_cast ops on high-D
141 /// vectors into 1-D/2-D vector ops by generating data movement extract/insert
142 /// ops.
144 
145 /// Returns the integer type required for subscripts in the vector dialect.
146 IntegerType getVectorSubscriptType(Builder &builder);
147 
148 /// Returns an integer array attribute containing the given values using
149 /// the integer type required for subscripts in the vector dialect.
150 ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef<int64_t> values);
151 
152 /// Returns the value obtained by reducing the vector into a scalar using the
153 /// operation kind associated with a binary AtomicRMWKind op.
154 Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder,
155  Location loc, Value vector);
156 
157 /// Return true if the last dimension of the MemRefType has unit stride. Also
158 /// return true for memrefs with no strides.
159 bool isLastMemrefDimUnitStride(MemRefType type);
160 
161 /// Build the default minor identity map suitable for a vector transfer. This
162 /// also handles the case memref<... x vector<...>> -> vector<...> in which the
163 /// rank of the identity map must take the vector element type into account.
164 AffineMap getTransferMinorIdentityMap(ShapedType shapedType,
165  VectorType vectorType);
166 
167 /// Return true if the transfer_write fully writes the data accessed by the
168 /// transfer_read.
169 bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read);
170 
171 /// Return true if the write op fully overwrites the priorWrite transfer_write
172 /// op.
173 bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite);
174 
175 /// Same behavior as `isDisjointTransferSet` but doesn't require the operations
176 /// to have the same tensor/memref. This allows comparing operations accessing
177 /// different tensors.
178 bool isDisjointTransferIndices(VectorTransferOpInterface transferA,
179  VectorTransferOpInterface transferB);
180 
181 /// Return true if we can prove that the transfer operations access disjoint
182 /// memory.
183 bool isDisjointTransferSet(VectorTransferOpInterface transferA,
184  VectorTransferOpInterface transferB);
185 } // namespace vector
186 } // namespace mlir
187 
188 #define GET_OP_CLASSES
189 #include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
190 #include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc"
191 
192 #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
Include the generated interface declarations.
Value getVectorReductionOp(arith::AtomicRMWKind op, OpBuilder &builder, Location loc, Value vector)
Returns the value obtained by reducing the vector into a scalar using the operation kind associated w...
Definition: VectorOps.cpp:449
void populateCastAwayVectorLeadingOneDimPatterns(RewritePatternSet &patterns)
Collect a set of leading one dimension removal patterns.
bool isLastMemrefDimUnitStride(MemRefType type)
Return true if the last dimension of the MemRefType has unit stride.
Definition: VectorOps.cpp:115
void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns)
Collect a set of patterns that bubble up/down bitcast ops.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
BroadcastableToResult isBroadcastableTo(Type srcType, VectorType dstVectorType, std::pair< int, int > *mismatchingDims=nullptr)
Definition: VectorOps.cpp:1692
BroadcastableToResult
Return whether srcType can be broadcast to dstVectorType under the semantics of the vector...
Definition: VectorOps.h:51
Attributes are known-constant values of operations.
Definition: Attributes.h:24
void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns)
Collect a set of vector-to-vector canonicalization patterns.
void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices)
These patterns materialize masks for various vector ops such as transfers.
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued...
Definition: AffineMap.h:41
bool checkSameValueWAW(TransferWriteOp write, TransferWriteOp priorWrite)
Return true if the write op fully overwrites the priorWrite transfer_write op.
This base class exposes generic asm parser hooks, usable across the various derived parsers...
static void print(ArrayType type, DialectAsmPrinter &os)
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns)
Collects patterns to progressively lower vector mask ops into elementary selection and insertion ops...
AffineMap getTransferMinorIdentityMap(ShapedType shapedType, VectorType vectorType)
Build the default minor identity map suitable for a vector transfer.
Definition: VectorOps.cpp:122
void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns)
Collects patterns to progressively lower vector.shape_cast ops on high-D vectors into 1-D/2-D vector ...
An attribute that specifies the combining function for vector.contract, and vector.reduction.
Definition: VectorOps.h:118
void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns)
Collect a set of patterns to propagate insert_map/extract_map in the ssa chain.
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
void populateShapeCastFoldingPatterns(RewritePatternSet &patterns)
Collect a set of vector.shape_cast folding patterns.
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
ArrayAttr getVectorSubscriptAttr(Builder &b, ArrayRef< int64_t > values)
Returns an integer array attribute containing the given values using the integer type required for su...
Definition: VectorOps.cpp:326
Utility class for implementing users of storage classes uniqued by a StorageUniquer.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns, llvm::Optional< unsigned > maxTransferRank=llvm::None)
Collect a set of transfer read/write lowering patterns.
bool isDisjointTransferIndices(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Same behavior as isDisjointTransferSet but doesn't require the operations to have the same tensor/mem...
Definition: VectorOps.cpp:157
IntegerType getVectorSubscriptType(Builder &builder)
Returns the integer type required for subscripts in the vector dialect.
Definition: VectorOps.cpp:322
This base class exposes generic asm printer hooks, usable across the various derived printers...
void populateVectorTransferDropUnitDimsPatterns(RewritePatternSet &patterns)
Collect a set of one dimension removal patterns.
bool checkSameValueRAW(TransferWriteOp defWrite, TransferReadOp read)
Return true if the transfer_write fully writes the data accessed by the transfer_read.
void populateFlattenVectorTransferPatterns(RewritePatternSet &patterns)
Collect a set of patterns to flatten n-D vector transfers on contiguous memref.
This class helps build Operations.
Definition: Builders.h:177
void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns)
Collects patterns to progressively lower vector.broadcast ops on high-D vectors to low-D vector ops...
bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB)
Return true if we can prove that the transfer operations access disjoint memory.
Definition: VectorOps.cpp:189