MLIR  17.0.0git
Utils.h
Go to the documentation of this file.
1 //===- Utils.h - General Arith transformation utilities ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines prototypes for various transformation utilities for
10 // the Arith dialect. These are not passes by themselves but are used
11 // either by passes, optimization sequences, or in turn by other transformation
12 // utilities.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_DIALECT_ARITH_UTILS_UTILS_H
17 #define MLIR_DIALECT_ARITH_UTILS_UTILS_H
18 
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
23 
24 namespace mlir {
25 
26 /// Matches a ConstantIndexOp.
27 detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex();
28 
29 /// Detects the `values` produced by a ConstantIndexOp and places the new
30 /// constant in place of the corresponding sentinel value.
32  function_ref<bool(int64_t)> isDynamic);
33 
34 llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
35  ArrayRef<int64_t> shape);
36 
37 /// Pattern to rewrite a subview op with constant arguments.
38 template <typename OpType, typename ResultTypeFunc, typename CastOpFunc>
40  : public OpRewritePattern<OpType> {
41 public:
43 
45  PatternRewriter &rewriter) const override {
46  // No constant operand, just return;
47  if (llvm::none_of(op.getOperands(), [](Value operand) {
48  return matchPattern(operand, matchConstantIndex());
49  }))
50  return failure();
51 
52  // At least one of offsets/sizes/strides is a new constant.
53  // Form the new list of operands and constant attributes from the existing.
54  SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
55  SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
56  SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
57  canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamic);
58  canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
59  canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamic);
60 
61  // Create the new op in canonical form.
62  ResultTypeFunc resultTypeFunc;
63  auto resultType =
64  resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
65  if (!resultType)
66  return failure();
67  auto newOp =
68  rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
69  mixedOffsets, mixedSizes, mixedStrides);
70  CastOpFunc func;
71  func(rewriter, op, newOp);
72 
73  return success();
74  }
75 };
76 
77 /// Converts an OpFoldResult to a Value. Returns the fold result if it casts to
78 /// a Value or creates a ConstantIndexOp if it casts to an IntegerAttribute.
79 /// Other attribute types are not supported.
80 Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
81  OpFoldResult ofr);
82 
83 /// Create a cast from an index-like value (index or integer) to another
84 /// index-like value. If the value type and the target type are the same, it
85 /// returns the original value.
86 Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc,
87  Type targetType, Value value);
88 
89 /// Similar to the other overload, but converts multiple OpFoldResults into
90 /// Values.
91 SmallVector<Value>
92 getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc,
93  ArrayRef<OpFoldResult> valueOrAttrVec);
94 
95 /// Converts a scalar value `operand` to type `toType`. If the value doesn't
96 /// convert, a warning will be issued and the operand is returned as is (which
97 /// will presumably yield a verification issue downstream).
98 Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand,
99  Type toType, bool isUnsignedCast);
100 
101 /// Helper struct to build simple arithmetic quantities with minimal type
102 /// inference support.
103 struct ArithBuilder {
104  ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {}
105 
106  Value _and(Value lhs, Value rhs);
107  Value add(Value lhs, Value rhs);
108  Value sub(Value lhs, Value rhs);
109  Value mul(Value lhs, Value rhs);
110  Value select(Value cmp, Value lhs, Value rhs);
111  Value sgt(Value lhs, Value rhs);
112  Value slt(Value lhs, Value rhs);
113 
114 private:
115  OpBuilder &b;
116  Location loc;
117 };
118 } // namespace mlir
119 
120 #endif // MLIR_DIALECT_ARITH_UTILS_UTILS_H
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:199
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
Pattern to rewrite a subview op with constant arguments.
Definition: Utils.h:40
LogicalResult matchAndRewrite(OpType op, PatternRewriter &rewriter) const override
Definition: Utils.h:44
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:621
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:83
void canonicalizeSubViewPart(SmallVectorImpl< OpFoldResult > &values, function_ref< bool(int64_t)> isDynamic)
Detects the values produced by a ConstantIndexOp and places the new constant in place of the corresponding sentinel value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:62
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:53
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:41
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:22
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition: Utils.h:103
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:149
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:136
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:159
ArithBuilder(OpBuilder &b, Location loc)
Definition: Utils.h:104
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:164
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:139
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:154
Value sub(Value lhs, Value rhs)
Definition: Utils.cpp:144
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:357