MLIR  16.0.0git
Tensor.h
Go to the documentation of this file.
1 //===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
10 #define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
11 
13 #include "mlir/IR/BuiltinTypes.h"
14 #include "mlir/IR/Dialect.h"
15 #include "mlir/IR/OpDefinition.h"
24 
25 //===----------------------------------------------------------------------===//
26 // Tensor Dialect Helpers
27 //===----------------------------------------------------------------------===//
28 
29 namespace mlir {
30 
31 /// Return the list of Range (i.e. offset, size, stride) of `op`. Each Range
32 /// entry contains either the dynamic value or a ConstantIndexOp constructed
33 /// with `b` at location `loc`. Note: defined in MemRefOps.cpp.
34 SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
35  OpBuilder &b, Location loc);
36 
37 } // namespace mlir
38 
39 //===----------------------------------------------------------------------===//
40 // Tensor Dialect
41 //===----------------------------------------------------------------------===//
42 
43 #include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
44 
45 //===----------------------------------------------------------------------===//
46 // Tensor Dialect Operations
47 //===----------------------------------------------------------------------===//
48 
49 #define GET_OP_CLASSES
50 #include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
51 
52 //===----------------------------------------------------------------------===//
53 // Tensor Dialect Helpers
54 //===----------------------------------------------------------------------===//
55 
56 namespace mlir {
57 namespace tensor {
58 
59 /// Returns true if `target` is a ranked tensor type that preserves static
60 /// information available in the `source` ranked tensor type.
61 bool preservesStaticInformation(Type source, Type target);
62 
63 /// Determines whether tensor::CastOp casts to a more dynamic version of the
64 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
65 /// implement canonicalization patterns for ops in different dialects that may
66 /// consume the results of tensor.cast operations. Such foldable tensor.cast
67 /// operations are typically inserted as `extract_slice` ops and are
68 /// canonicalized, to preserve the type compatibility of their uses.
69 ///
70 /// Returns true when all conditions are met:
71 /// 1. source and result are ranked tensors with same element type and rank.
72 /// 2. the tensor type has more static information than the result
73 ///
74 /// Example:
75 /// ```mlir
76 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
77 /// %2 = consumer %1 ... : tensor<?x?xf32> ...
78 /// ```
79 ///
80 /// folds into:
81 ///
82 /// ```mlir
83 /// %2 = consumer %0 ... : tensor<8x16xf32> ...
84 /// ```
85 bool canFoldIntoConsumerOp(CastOp castOp);
86 
87 /// Determines whether the tensor::CastOp casts to a more static version of the
88 /// source tensor. This is useful to fold into a producing op and implement
89 /// canonicaliation patterns with the `tensor.cast` op as the root, but producer
90 /// being from different dialects. Returns true when all conditions are met:
91 /// 1. source and result and ranked tensors with same element type and rank.
92 /// 2. the result type has more static information than the source.
93 ///
94 /// Example:
95 /// ```mlir
96 /// %1 = producer ... : tensor<?x?xf32>
97 /// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
98 /// ```
99 ///
100 /// can be canonicalized to :
101 ///
102 /// ```mlir
103 /// %2 = producer ... : tensor<8x16xf32>
104 /// ```
105 /// Not all ops might be canonicalizable this way, but for those that can be,
106 /// this method provides a check that it is worth doing the canonicalization.
107 bool canFoldIntoProducerOp(CastOp castOp);
108 
109 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
110 /// that can be folded.
111 LogicalResult foldTensorCast(Operation *op);
112 
113 /// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
114 /// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
115 /// to that of `targetType`.
116 Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
117  Value tensor,
118  RankedTensorType targetType);
119 
120 /// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
121 /// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with
122 /// rank increased to that of `dest`, obtained by inserting `tensor` into `dest`
123 /// at the canonical [0 .. 0] position.
124 Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
125  Value tensor, Value dest);
126 
127 /// Function to control the folding of constant and extract slice
128 using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
129 
130 /// Patterns to fold the extract slice op with its constant operand
132  RewritePatternSet &patterns,
133  const ControlConstantExtractSliceFusionFn &controlFn =
134  [](ExtractSliceOp op) {
135  // Disable by default because the folding can generate a large
136  // constant tensor, which would affect the compile time and storage.
137  return false;
138  });
139 
140 } // namespace tensor
141 } // namespace mlir
142 
143 #endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
Include the generated interface declarations.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded...
Definition: TensorOps.cpp:132
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
Definition: TensorOps.cpp:45
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition: Tensor.h:128
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
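Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes to reduce the rank of tensor to that of targetType.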
Definition: TensorOps.cpp:1464
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Definition: TensorOps.cpp:93
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size, stride).
Definition: MemRefOps.cpp:2404
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
Definition: TensorOps.cpp:1376
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor...
Definition: TensorOps.cpp:123
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes (i.e. dest.getSizes()).
Definition: TensorOps.cpp:1815