MLIR  20.0.0git
Utils.h
Go to the documentation of this file.
1 //===- Utils.h - General Arith transformation utilities ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines prototypes for various transformation utilities for
10 // the Arith dialect. These are not passes by themselves but are used
11 // either by passes, optimization sequences, or in turn by other transformation
12 // utilities.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_DIALECT_ARITH_UTILS_UTILS_H
17 #define MLIR_DIALECT_ARITH_UTILS_UTILS_H
18 
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/Value.h"
23 #include "llvm/ADT/ArrayRef.h"
24 
25 namespace mlir {
26 
28 
29 /// Infer the output shape for a {memref|tensor}.expand_shape when it is
30 /// possible to do so.
31 ///
32 /// Note: This should *only* be used to implement
33 /// `ExpandShapeOp::inferOutputShape` in both the memref and tensor namespaces.
34 /// If you need to infer the output shape you should use the static method of
35 /// `ExpandShapeOp` instead of calling this.
36 ///
37 /// `inputShape` is the shape of the tensor or memref being expanded as a
38 /// sequence of SSA values or constants. `expandedType` is the output shape of
39 /// the expand_shape operation. `reassociation` is the reassociation denoting
40 /// the output dims each input dim is mapped to.
41 ///
42 /// Returns the inferred output shape as a sequence of `OpFoldResult`s, following
43 /// the conventions for the output_shape and static_output_shape inputs to the
44 /// expand_shape ops, or `std::nullopt` when the shape cannot be inferred.
45 std::optional<SmallVector<OpFoldResult>>
46 inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType,
47  ArrayRef<ReassociationIndices> reassociation,
48  ArrayRef<OpFoldResult> inputShape);
49 
50 /// Matches a ConstantIndexOp.
51 detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
52 
53 llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
54  ArrayRef<int64_t> shape);
55 
56 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
57 /// a Value or creates a ConstantOp if it casts to an Integer Attribute.
58 /// Other attribute types are not supported.
59 Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc,
60  OpFoldResult ofr);
61 
62 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
63 /// a Value or creates a ConstantIndexOp if it casts to an Integer Attribute.
64 /// Other attribute types are not supported.
65 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
66  OpFoldResult ofr);
67 
68 /// Similar to the other overload, but converts multiple OpFoldResults into
69 /// Values.
72  ArrayRef<OpFoldResult> valueOrAttrVec);
73 
74 /// Create a cast from an index-like value (index or integer) to another
75 /// index-like value. If the value type and the target type are the same, it
76 /// returns the original value.
77 Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
78  Type targetType, Value value);
79 
80 /// Converts a scalar value `operand` to type `toType`. If the value doesn't
81 /// convert, a warning will be issued and the operand is returned as is (which
82 /// will presumably yield a verification issue downstream).
83 Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
84  Type toType, bool isUnsignedCast);
85 
86 /// Create a constant of type `type` at location `loc` whose value is `value`
87 /// (an APInt or APFloat whose type must match the element type of `type`).
88 /// If `type` is a shaped type, create a splat constant of the given value.
89 /// Constants are folded if possible.
90 Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type,
91  const APInt &value);
93  int64_t value);
95  const APFloat &value);
96 
97 /// Returns the int type of the integer in ofr.
98 /// Other attribute types are not supported.
99 Type getType(OpFoldResult ofr);
100 
101 /// Helper struct to build simple arithmetic quantities with minimal type
102 /// inference support.
103 struct ArithBuilder {
104  ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
105 
106  Value _and(Value lhs, Value rhs);
107  Value add(Value lhs, Value rhs);
108  Value sub(Value lhs, Value rhs);
109  Value mul(Value lhs, Value rhs);
110  Value select(Value cmp, Value lhs, Value rhs);
111  Value sgt(Value lhs, Value rhs);
112  Value slt(Value lhs, Value rhs);
113 
114 private:
115  OpBuilder &b;
116  Location loc;
117 };
118 
119 namespace arith {
120 
121 // Build the product of a sequence.
122 // If values = (v0, v1, ..., vn) then the returned
123 // value is v0 * v1 * ... * vn.
124 // All values must have the same type.
125 //
126 // The version without `resultType` requires `values` to contain at least one
127 // element; the result then has the same type as the elements in `values`.
128 // The version with `resultType` returns 1 with type `resultType` if `values`
129 // is empty.
130 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values);
131 Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
132  Type resultType);
133 
134 // Map strings to float types.
135 std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
136 
137 } // namespace arith
138 } // namespace mlir
139 
140 #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:216
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition: Utils.cpp:361
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition: Utils.cpp:346
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:271
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:120
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:91
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:87
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition: Utils.h:103
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:325
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:312
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:335
ArithBuilder(OpBuilder &b, Location loc)
Definition: Utils.h:104
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:340
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:315
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:330
Value sub(Value lhs, Value rhs)
Definition: Utils.cpp:320
The matcher that matches a certain kind of op.
Definition: Matchers.h:283