MLIR  17.0.0git
Tensor.h
Go to the documentation of this file.
1 //===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
10 #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
11 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/OpDefinition.h"
26 
27 //===----------------------------------------------------------------------===//
28 // Tensor Dialect Helpers
29 //===----------------------------------------------------------------------===//
30 
31 namespace mlir {
32 
33 /// Return the list of Range (i.e. offset, size, stride). Each Range
34 /// entry contains either the dynamic value or a ConstantIndexOp constructed
35 /// with `b` at location `loc`.
36 SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
37  OpBuilder &b, Location loc);
38 
39 } // namespace mlir
40 
41 //===----------------------------------------------------------------------===//
42 // Tensor Dialect
43 //===----------------------------------------------------------------------===//
44 
45 #include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
46 
47 //===----------------------------------------------------------------------===//
48 // Tensor Dialect Operations
49 //===----------------------------------------------------------------------===//
50 
51 #define GET_OP_CLASSES
52 #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
53 
54 //===----------------------------------------------------------------------===//
55 // Tensor Dialect Helpers
56 //===----------------------------------------------------------------------===//
57 
58 namespace mlir {
59 namespace tensor {
60 
61 /// Returns true if `target` is a ranked tensor type that preserves static
62 /// information available in the `source` ranked tensor type.
63 bool preservesStaticInformation(Type source, Type target);
64 
65 /// Determines whether tensor::CastOp casts to a more dynamic version of the
66 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
67 /// implement canonicalization patterns for ops in different dialects that may
68 /// consume the results of tensor.cast operations. Such foldable tensor.cast
68 /// operations are typically inserted when `extract_slice` ops are
69 /// canonicalized, to preserve the type compatibility of their uses.
71 ///
72 /// Returns true when all conditions are met:
73 /// 1. source and result are ranked tensors with same element type and rank.
74 /// 2. the source type has more static information than the result type.
75 ///
76 /// Example:
77 /// ```mlir
78 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
79 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
80 /// ```
81 ///
82 /// folds into:
83 ///
84 /// ```mlir
85 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
86 /// ```
87 bool canFoldIntoConsumerOp(CastOp castOp);
88 
89 /// Determines whether the tensor::CastOp casts to a more static version of the
90 /// source tensor. This is useful to fold into a producing op and implement
91 /// canonicalization patterns with the `tensor.cast` op as the root, but producer
92 /// being from different dialects. Returns true when all conditions are met:
93 /// 1. source and result are ranked tensors with same element type and rank.
94 /// 2. the result type has more static information than the source.
95 ///
96 /// Example:
97 /// ```mlir
98 /// %1 = producer ... : tensor<?x?xf32>
99 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
100 /// ```
101 ///
102 /// can be canonicalized to:
103 ///
104 /// ```mlir
105 /// %2 = producer ... : tensor<8x16xf32>
106 /// ```
107 /// Not all ops might be canonicalizable this way, but for those that can be,
108 /// this method provides a check that it is worth doing the canonicalization.
109 bool canFoldIntoProducerOp(CastOp castOp);
110 
111 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
112 /// that can be folded.
113 LogicalResult foldTensorCast(Operation *op);
114 
115 /// Return the dimensions of the given tensor value.
116 SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
117  Value value);
118 
119 /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
120 /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
121 /// to that of `targetType`.
122 Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
123  Value tensor,
124  RankedTensorType targetType);
125 
126 /// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
127 /// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with
128 /// rank increased to that of `dest`, obtained by inserting `tensor` into `dest`
129 /// at the canonical [0 .. 0] position.
130 Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
131  Value tensor, Value dest);
132 
133 /// This is a helper function for DestinationStyleOpInterface. If there is a
134 /// destination operand for the given OpResult, return that operand. Otherwise,
135 /// return an empty tensor (`tensor.empty`) with the shape of the OpResult.
136 /// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface.
137 FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
138  OpResult opResult);
139 
140 /// This is a helper function for DestinationStyleOpInterface. Get or create
141 /// destinations for every tensor OpResult of the given op.
142 LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
143  SmallVector<Value> &result);
144 
145 /// Function to control the folding of constant and extract slice.
146 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
147 
148 /// Patterns to fold the extract slice op with its constant operand.
149 void populateFoldConstantExtractSlicePatterns(
150  RewritePatternSet &patterns,
151  const ControlConstantExtractSliceFusionFn &controlFn =
152  [](ExtractSliceOp op) {
153  // Disable by default because the folding can generate a large
154  // constant tensor, which would affect the compile time and storage.
155  return false;
156  });
157 
158 /// Patterns to simplify tensor.pack.
159 void populateSimplifyTensorPack(RewritePatternSet &patterns);
160 
161 } // namespace tensor
162 } // namespace mlir
163 
164 #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
Definition: TensorOps.cpp:214
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:1982
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:205
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:175
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes (i.e. `dest.getSizes()`).
Definition: TensorOps.cpp:2421
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
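Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor` to that of `targetType`.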
Definition: TensorOps.cpp:2070
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:64
void populateSimplifyTensorPack(RewritePatternSet &patterns)
Patterns to simplify tensor.pack.
Definition: TensorOps.cpp:3057
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:146
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:127
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:49
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Definition: TensorOps.cpp:103
Include the generated interface declarations.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size, stride).
Definition: MemRefOps.cpp:2853