MLIR  19.0.0git
Tensor.h
Go to the documentation of this file.
1 //===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
10 #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
11 
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/OpDefinition.h"
27 
28 //===----------------------------------------------------------------------===//
29 // Tensor Dialect Helpers
30 //===----------------------------------------------------------------------===//
31 
32 namespace mlir {
33 
34 /// Return the list of Range (i.e. offset, size, stride). Each Range
35 /// entry contains either the dynamic value or a ConstantIndexOp constructed
36 /// with `b` at location `loc`.
37 SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
38  OpBuilder &b, Location loc);
39 
40 } // namespace mlir
41 
42 //===----------------------------------------------------------------------===//
43 // Tensor Dialect
44 //===----------------------------------------------------------------------===//
45 
46 #include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
47 
48 //===----------------------------------------------------------------------===//
49 // Tensor Dialect Operations
50 //===----------------------------------------------------------------------===//
51 
52 #define GET_OP_CLASSES
53 #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
54 
55 //===----------------------------------------------------------------------===//
56 // Tensor Dialect Helpers
57 //===----------------------------------------------------------------------===//
58 
59 namespace mlir {
60 namespace tensor {
61 
62 /// Returns true if `target` is a ranked tensor type that preserves static
63 /// information available in the `source` ranked tensor type.
64 bool preservesStaticInformation(Type source, Type target);
65 
66 /// Determines whether tensor::CastOp casts to a more dynamic version of the
67 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
68 /// implement canonicalization patterns for ops in different dialects that may
69 /// consume the results of tensor.cast operations. Such foldable tensor.cast
70 /// operations are typically inserted as `extract_slice` ops and are
71 /// canonicalized, to preserve the type compatibility of their uses.
72 ///
73 /// Returns true when all conditions are met:
74 /// 1. source and result are ranked tensors with same element type and rank.
75 /// 2. the tensor type has more static information than the result
76 ///
77 /// Example:
78 /// ```mlir
79 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
80 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
81 /// ```
82 ///
83 /// folds into:
84 ///
85 /// ```mlir
86 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
87 /// ```
88 bool canFoldIntoConsumerOp(CastOp castOp);
89 
90 /// Determines whether the tensor::CastOp casts to a more static version of the
91 /// source tensor. This is useful to fold into a producing op and implement
92 /// canonicaliation patterns with the `tensor.cast` op as the root, but producer
93 /// being from different dialects. Returns true when all conditions are met:
94 /// 1. source and result and ranked tensors with same element type and rank.
95 /// 2. the result type has more static information than the source.
96 ///
97 /// Example:
98 /// ```mlir
99 /// %1 = producer ... : tensor<?x?xf32>
100 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
101 /// ```
102 ///
103 /// can be canonicalized to :
104 ///
105 /// ```mlir
106 /// %2 = producer ... : tensor<8x16xf32>
107 /// ```
108 /// Not all ops might be canonicalizable this way, but for those that can be,
109 /// this method provides a check that it is worth doing the canonicalization.
110 bool canFoldIntoProducerOp(CastOp castOp);
111 
112 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
113 /// that can be folded.
114 LogicalResult foldTensorCast(Operation *op);
115 
116 /// Return the dimension of the given tensor value.
117 OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value,
118  int64_t dim);
119 
120 /// Return the dimensions of the given tensor value.
121 SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
122  Value value);
123 
124 /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
125 /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
126 /// to that of `targetType`.
127 Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
128  Value tensor,
129  RankedTensorType targetType);
130 
131 /// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
132 /// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with
133 /// rank increased to that of `dest`, obtained by inserting `tensor` into `dest`
134 /// at the canonical [0 .. 0] position.
135 Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
136  Value tensor, Value dest);
137 
138 /// This is a helper function for DestinationStyleOpInterface. If there is a
139 /// destination operand for the given OpResult, return that operand. Otherwise,
140 /// return an empty tensor (`tensor.empty`) with the shape of the OpResult.
141 /// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface.
142 FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
143  OpResult opResult);
144 
145 /// This is a helper function for DestinationStyleOpInterface. Get or create
146 /// destinations for every tensor OpResult of the given op.
147 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
148  SmallVector<Value> &result);
149 
150 /// Tests if types are the same when ignoring encoding on ranked tensors.
151 bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
152 
153 /// Function to control the folding of constant and extract slice.
154 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
155 
156 /// Patterns to fold the extract slice op with its constant operand.
158  RewritePatternSet &patterns,
159  const ControlConstantExtractSliceFusionFn &controlFn =
160  [](ExtractSliceOp op) {
161  // Disable by default because the folding can generate a large
162  // constant tensor, which would affect the compile time and storage.
163  return false;
164  });
165 
166 } // namespace tensor
167 } // namespace mlir
168 
169 #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:354
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2384
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:345
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:315
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes (i.e. `dest.getSizes()`).
Definition: TensorOps.cpp:2851
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes (i.e. `tensor.getSizes()`).
Definition: TensorOps.cpp:2471
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:119
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the size of dimension `dim` of the given tensor value.
Definition: TensorOps.cpp:51
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:70
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:154
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the source ranked tensor type.
Definition: TensorOps.cpp:263
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the sizes of all dimensions of the given tensor value.
Definition: TensorOps.cpp:61
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:105
Include the generated interface declarations.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size, stride).
Definition: MemRefOps.cpp:3013