MLIR  21.0.0git
Tensor.h
Go to the documentation of this file.
1 //===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
10 #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
11 
14 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Dialect.h"
16 #include "mlir/IR/OpDefinition.h"
28 
29 //===----------------------------------------------------------------------===//
30 // Tensor Dialect Helpers
31 //===----------------------------------------------------------------------===//
32 
33 namespace mlir {
34 
35 /// Return the list of Ranges (i.e. offset, size, stride) of `op`. Each Range
36 /// entry contains either the dynamic value or a ConstantIndexOp constructed
37 /// with `b` at location `loc`. Note: the definition lives in MemRefOps.cpp.
38 SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
39  OpBuilder &b, Location loc);
40 
41 } // namespace mlir
42 
43 //===----------------------------------------------------------------------===//
44 // Tensor Dialect
45 //===----------------------------------------------------------------------===//
46 
47 #include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
48 
49 //===----------------------------------------------------------------------===//
50 // Tensor Dialect Operations
51 //===----------------------------------------------------------------------===//
52 
53 #define GET_OP_CLASSES
54 #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
55 
56 //===----------------------------------------------------------------------===//
57 // Tensor Dialect Helpers
58 //===----------------------------------------------------------------------===//
59 
60 namespace mlir {
61 namespace tensor {
62 
63 /// Returns true if `target` is a ranked tensor type that preserves static
64 /// information available in the `source` ranked tensor type.
65 bool preservesStaticInformation(Type source, Type target);
66 
67 /// Determines whether tensor::CastOp casts to a more dynamic version of the
68 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
69 /// implement canonicalization patterns for ops in different dialects that may
70 /// consume the results of tensor.cast operations. Such foldable tensor.cast
71 /// operations are typically inserted when `extract_slice` ops are
72 /// canonicalized, to preserve the type compatibility of their uses.
73 ///
74 /// Returns true when all conditions are met:
75 /// 1. source and result are ranked tensors with the same element type and rank.
76 /// 2. the source tensor type has more static information than the result type.
77 ///
78 /// Example:
79 /// ```mlir
80 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
81 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
82 /// ```
83 ///
84 /// folds into:
85 ///
86 /// ```mlir
87 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
88 /// ```
89 bool canFoldIntoConsumerOp(CastOp castOp);
90 
91 /// Determines whether the tensor::CastOp casts to a more static version of the
92 /// source tensor. This is useful to fold into a producing op and implement
93 /// canonicalization patterns with the `tensor.cast` op as the root, but with a
94 /// producer from a different dialect. Returns true when all conditions are met:
95 /// 1. source and result are ranked tensors with the same element type and rank.
96 /// 2. the result type has more static information than the source type.
97 ///
98 /// Example:
99 /// ```mlir
100 /// %1 = producer ... : tensor<?x?xf32>
101 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
102 /// ```
103 ///
104 /// can be canonicalized to:
105 ///
106 /// ```mlir
107 /// %2 = producer ... : tensor<8x16xf32>
108 /// ```
109 /// Not all ops might be canonicalizable this way, but for those that can be,
110 /// this method provides a check that it is worth doing the canonicalization.
111 bool canFoldIntoProducerOp(CastOp castOp);
112 
113 /// Return true if any of the operands of `op` is a CastOp that can be folded
114 /// into its consumer, i.e. `op`. This is effectively a convenience wrapper for
115 /// `canFoldIntoConsumerOp`.
116 bool hasFoldableTensorCastOperand(Operation *op);
117 
118 /// Assuming that `op` has at least one operand that is a foldable CastOp (i.e.
119 /// `hasFoldableTensorCastOperand` returns true), calculate the updated operands.
120 /// `newResTy` is a non-const ref; presumably it receives the new result types.
121 SmallVector<Value>
122 getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
123  SmallVector<Type> &newResTy);
124 
125 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
126 /// that can be folded. Presumably returns success() iff something was folded.
127 LogicalResult foldTensorCast(Operation *op);
128 
129 /// Return the size of dimension `dim` of the given tensor value.
130 OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value,
131  int64_t dim);
132 
133 /// Return the sizes of all dimensions of the given tensor value.
134 SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
135  Value value);
136 
137 /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
138 /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
139 /// to that of `targetType`.
140 Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
141  Value tensor,
142  RankedTensorType targetType);
143 
144 /// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
145 /// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with
146 /// rank increased to that of `dest`, obtained by inserting `tensor` into `dest`
147 /// at the canonical [0 .. 0] position.
148 Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
149  Value tensor, Value dest);
150 
151 /// This is a helper function for DestinationStyleOpInterface. If there is a
152 /// destination operand for the given OpResult, return that operand. Otherwise,
153 /// return an empty tensor (`tensor.empty`) with the shape of the OpResult.
154 /// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface.
155 FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
156  OpResult opResult);
157 
158 /// This is a helper function for DestinationStyleOpInterface. Get or create
159 /// destinations for every tensor OpResult of `op` (presumably into `result`).
160 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
161  SmallVector<Value> &result);
162 
163 /// Tests if types are the same when ignoring encoding on ranked tensors.
164 bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
165 
166 /// Callback returning true if an ExtractSliceOp may fold with its constant.
167 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
168 
169 /// Patterns to fold the extract slice op with its constant operand.
172  const ControlConstantExtractSliceFusionFn &controlFn =
173  [](ExtractSliceOp op) {
174  // Disable by default because the folding can generate a large
175  // constant tensor, which would affect the compile time and storage.
176  return false;
177  });
178 
179 } // namespace tensor
180 } // namespace mlir
181 
182 #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:389
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer, i.e. op.
Definition: TensorOps.cpp:358
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:2501
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:351
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e. hasFoldableTensorCastOperand returns true), calculate the updated operands.
Definition: TensorOps.cpp:367
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:321
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes (i.e. dest.getSizes()).
Definition: TensorOps.cpp:2993
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
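Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes (i.e. tensor.getSizes()) to reduce the rank of tensor to that of targetType.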
Definition: TensorOps.cpp:2588
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
Definition: TensorOps.cpp:127
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition: TensorOps.cpp:60
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:78
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:167
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:269
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:69
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:113
Include the generated interface declarations.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size, stride).
Definition: MemRefOps.cpp:3006
const FrozenRewritePatternSet & patterns