MLIR  16.0.0git
Utils.h
Go to the documentation of this file.
1 //===- Utils.h - General Arithmetic transformation utilities ----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This header file defines prototypes for various transformation utilities for
10 // the Arithmetic dialect. These are not passes by themselves but are used
11 // either by passes, optimization sequences, or in turn by other transformation
12 // utilities.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef MLIR_DIALECT_ARITHMETIC_UTILS_UTILS_H
17 #define MLIR_DIALECT_ARITHMETIC_UTILS_UTILS_H
18 
20 #include "mlir/IR/Matchers.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/IR/Value.h"
23 
24 namespace mlir {
25 
26 /// Returns a matcher that matches an arith::ConstantIndexOp.
27 detail::op_matcher<arith::ConstantIndexOp> matchConstantIndex(); // e.g. matchPattern(value, matchConstantIndex())
28 
29 /// Detects the entries of `values` produced by a ConstantIndexOp and replaces
30 /// them with the corresponding constant attribute.
31 void canonicalizeSubViewPart(SmallVectorImpl<OpFoldResult> &values,
32  function_ref<bool(int64_t)> isDynamic); // predicate for the dynamic sentinel, e.g. ShapedType::isDynamic for sizes
33 
34 llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank,
35  ArrayRef<int64_t> shape); // NOTE(review): undocumented; presumably sets the bits of the dims of `shape` equal to 1 (bit vector of size `rank`) — confirm in Utils.cpp
36 
37 /// Pattern that folds constant offset/size/stride operands of a subview-like op into static entries.
38 template <typename OpType, typename ResultTypeFunc, typename CastOpFunc> // ResultTypeFunc: computes the new result type (null => no rewrite); CastOpFunc: called with (rewriter, op, newOp)
40  : public OpRewritePattern<OpType> {
41 public:
43 
45  PatternRewriter &rewriter) const override {
46  // No constant operand: nothing to fold, so report a match failure.
47  if (llvm::none_of(op.getOperands(), [](Value operand) {
48  return matchPattern(operand, matchConstantIndex());
49  }))
50  return failure();
51 
52  // At least one of offsets/sizes/strides is a new constant.
53  // Form the new list of operands and constant attributes from the existing.
54  SmallVector<OpFoldResult> mixedOffsets(op.getMixedOffsets());
55  SmallVector<OpFoldResult> mixedSizes(op.getMixedSizes());
56  SmallVector<OpFoldResult> mixedStrides(op.getMixedStrides());
57  canonicalizeSubViewPart(mixedOffsets, ShapedType::isDynamicStrideOrOffset);
58  canonicalizeSubViewPart(mixedSizes, ShapedType::isDynamic);
59  canonicalizeSubViewPart(mixedStrides, ShapedType::isDynamicStrideOrOffset);
60 
61  // Create the new op in canonical form.
62  ResultTypeFunc resultTypeFunc;
63  auto resultType =
64  resultTypeFunc(op, mixedOffsets, mixedSizes, mixedStrides);
65  if (!resultType) // no valid result type could be computed; leave `op` untouched
66  return failure();
67  auto newOp =
68  rewriter.create<OpType>(op.getLoc(), resultType, op.getSource(),
69  mixedOffsets, mixedSizes, mixedStrides);
70  CastOpFunc func;
71  func(rewriter, op, newOp); // NOTE(review): presumably replaces `op` with `newOp`, casting if the types differ — confirm against CastOpFunc implementations
72 
73  return success();
74  }
75 };
76 
77 /// Converts an OpFoldResult to a Value. Returns the fold result itself if it
78 /// holds a Value, or creates a ConstantIndexOp if it holds an IntegerAttr.
79 /// Other attribute types are not supported.
81  OpFoldResult ofr);
82 
83 /// Creates a cast from an index-like value (index or integer) to another
84 /// index-like type `targetType`. If `value` already has type `targetType`,
85 /// it is returned unchanged.
87  Type targetType, Value value);
88 
89 /// Like the single-OpFoldResult overload, but converts every element of
90 /// `valueOrAttrVec` into a Value.
93  ArrayRef<OpFoldResult> valueOrAttrVec);
94 
95 /// Helper struct to build simple arithmetic quantities at a fixed Location
96 /// through a borrowed OpBuilder, with minimal type inference support.
97 struct ArithBuilder {
98  ArithBuilder(OpBuilder &b, Location loc) : b(b), loc(loc) {} // `b` is held by reference and must outlive this object
99 
100  Value _and(Value lhs, Value rhs); // leading underscore because `and` is a reserved C++ alternative token
101  Value add(Value lhs, Value rhs);
102  Value mul(Value lhs, Value rhs);
103  Value select(Value cmp, Value lhs, Value rhs); // presumably `cmp ? lhs : rhs` — confirm in Utils.cpp
104  Value sgt(Value lhs, Value rhs); // NOTE(review): presumably signed `lhs > rhs`; float handling not visible here — confirm
105  Value slt(Value lhs, Value rhs); // NOTE(review): presumably signed `lhs < rhs` — confirm
106 
107 private:
108  OpBuilder &b; // non-owning
109  Location loc;
110 };
111 } // namespace mlir
112 
113 #endif // MLIR_DIALECT_ARITHMETIC_UTILS_UTILS_H
Include the generated interface declarations.
Pattern to rewrite a subview op with constant arguments.
Definition: Utils.h:39
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
llvm::function_ref< Fn > function_ref
Definition: LLVM.h:149
This class represents a single result from folding an operation.
Definition: OpDefinition.h:239
ArithBuilder(OpBuilder &b, Location loc)
Definition: Utils.h:98
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition: Utils.h:97
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
LogicalResult matchAndRewrite(OpType op, PatternRewriter &rewriter) const override
Definition: Utils.h:44
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void canonicalizeSubViewPart(SmallVectorImpl< OpFoldResult > &values, function_ref< bool(int64_t)> isDynamic)
Detects the values produced by a ConstantIndexOp and places the new constant in place of the correspo...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:41
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:62
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:355
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:332
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:53
This class helps build Operations.
Definition: Builders.h:192
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:22