MLIR 22.0.0git
Tensor.h
Go to the documentation of this file.
1//===- Tensor.h - Tensor dialect --------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_TENSOR_IR_TENSOR_H_
10#define MLIR_DIALECT_TENSOR_IR_TENSOR_H_
11
15#include "mlir/IR/Dialect.h"
28
29//===----------------------------------------------------------------------===//
30// Tensor Dialect Helpers
31//===----------------------------------------------------------------------===//
32
33namespace mlir {
34
35/// Return the list of Range (i.e. offset, size, stride). Each Range
36/// entry contains either the dynamic value or a ConstantIndexOp constructed
37/// with `b` at location `loc`.
38SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
39 OpBuilder &b, Location loc);
40
41} // namespace mlir
42
43//===----------------------------------------------------------------------===//
44// Tensor Dialect
45//===----------------------------------------------------------------------===//
46
47#include "mlir/Dialect/Tensor/IR/TensorOpsDialect.h.inc"
48
49//===----------------------------------------------------------------------===//
50// Tensor Dialect Operations
51//===----------------------------------------------------------------------===//
52
53#define GET_OP_CLASSES
54#include "mlir/Dialect/Tensor/IR/TensorOps.h.inc"
55
56//===----------------------------------------------------------------------===//
57// Tensor Dialect Helpers
58//===----------------------------------------------------------------------===//
59
60namespace mlir {
61namespace tensor {
62
63/// Returns true if `target` is a ranked tensor type that preserves static
64/// information available in the `source` ranked tensor type.
65bool preservesStaticInformation(Type source, Type target);
66
67/// Determines whether tensor::CastOp casts to a more dynamic version of the
68/// source tensor. This is useful to fold a tensor.cast into a consuming op and
69/// implement canonicalization patterns for ops in different dialects that may
70/// consume the results of tensor.cast operations. Such foldable tensor.cast
71/// operations are typically inserted when `extract_slice` ops are
72/// canonicalized, to preserve the type compatibility of their uses.
73///
74/// Returns true when all conditions are met:
75/// 1. source and result are ranked tensors with same element type and rank.
76/// 2. the source type has more static information than the result type.
77///
78/// Example:
79/// ```mlir
80/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
81/// %2 = consumer %1 ... : tensor<?x?xf32> ...
82/// ```
83///
84/// folds into:
85///
86/// ```mlir
87/// %2 = consumer %0 ... : tensor<8x16xf32> ...
88/// ```
89bool canFoldIntoConsumerOp(CastOp castOp);
90
91/// Determines whether the tensor::CastOp casts to a more static version of the
92/// source tensor. This is useful to fold into a producing op and implement
93/// canonicalization patterns with the `tensor.cast` op as the root, but producer
94/// being from different dialects. Returns true when all conditions are met:
95/// 1. source and result are ranked tensors with same element type and rank.
96/// 2. the result type has more static information than the source.
97///
98/// Example:
99/// ```mlir
100/// %1 = producer ... : tensor<?x?xf32>
101/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
102/// ```
103///
104/// can be canonicalized to :
105///
106/// ```mlir
107/// %2 = producer ... : tensor<8x16xf32>
108/// ```
109/// Not all ops might be canonicalizable this way, but for those that can be,
110/// this method provides a check that it is worth doing the canonicalization.
111bool canFoldIntoProducerOp(CastOp castOp);
112
113/// Return true if any of the operands of `op` is a CastOp that can be folded
114/// into its consumer, i.e. `op`. This is effectively a convenience wrapper for
115/// `canFoldIntoConsumerOp`.
116bool hasFoldableTensorCastOperand(Operation *op);
117
118/// Assuming that `op` contains at least one operand that is a foldable CastOp
119/// (i.e. `hasFoldableTensorCastOperand` returns true), calculate the updated
120/// operands.
121SmallVector<Value>
122getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op,
123 SmallVector<Type> &newResTy);
124
125/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
126/// that can be folded.
127LogicalResult foldTensorCast(Operation *op);
128
129/// Return the dimension of the given tensor value.
130OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value,
131 int64_t dim);
132
133/// Return the dimensions of the given tensor value.
134SmallVector<OpFoldResult> getMixedSizes(OpBuilder &builder, Location loc,
135 Value value);
136
137/// Create a rank-reducing ExtractSliceOp @[0 .. 0] with strides [1 .. 1] and
138/// appropriate sizes (i.e. `tensor.getSizes()`) to reduce the rank of `tensor`
139/// to that of `targetType`.
140Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc,
141 Value tensor,
142 RankedTensorType targetType);
143
144/// Create a rank-reducing InsertSliceOp @[0 .. 0] with strides [1 .. 1] and
145/// appropriate sizes (i.e. `dest.getSizes()`). The result is a new tensor with
146/// rank increased to that of `dest`, obtained by inserting `tensor` into `dest`
147/// at the canonical [0 .. 0] position.
148Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc,
149 Value tensor, Value dest);
150
151/// This is a helper function for DestinationStyleOpInterface. If there is a
152/// destination operand for the given OpResult, return that operand. Otherwise,
153/// return an empty tensor (`tensor.empty`) with the shape of the OpResult.
154/// Dynamic dimensions are queried via ReifyRankedShapedTypeOpInterface.
155FailureOr<Value> getOrCreateDestination(OpBuilder &b, Location loc,
156 OpResult opResult);
157
158/// This is a helper function for DestinationStyleOpInterface. Get or create
159/// destinations for every tensor OpResult of the given op.
160LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op,
161 SmallVector<Value> &result);
162
163/// Tests if types are the same when ignoring encoding on ranked tensors.
164bool isSameTypeWithoutEncoding(Type tp1, Type tp2);
165
166/// Function to control the folding of constant and extract slice.
167using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
168
169/// Patterns to fold the extract slice op with its constant operand.
170void populateFoldConstantExtractSlicePatterns(
171 RewritePatternSet &patterns,
172 const ControlConstantExtractSliceFusionFn &controlFn =
173 [](ExtractSliceOp op) {
174 // Disable by default because the folding can generate a large
175 // constant tensor, which would affect the compile time and storage.
176 return false;
177 });
178
179/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
180/// source tensor.
181void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
182
183} // namespace tensor
184} // namespace mlir
185
186#endif // MLIR_DIALECT_TENSOR_IR_TENSOR_H_
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:207
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
Definition TensorOps.cpp:57
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
Definition TensorOps.cpp:75
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Definition Tensor.h:167
Include the generated interface declarations.
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
const FrozenRewritePatternSet & patterns