MLIR  16.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "llvm/ADT/SmallBitVector.h"
16 
17 using namespace mlir;
18 
19 /// Matches a ConstantIndexOp.
20 /// TODO: This should probably just be a general matcher that uses matchConstant
21 /// and checks the operation for an index type.
24 }
25 
26 /// Detects the `values` produced by a ConstantIndexOp and places the new
27 /// constant in place of the corresponding sentinel value.
30  llvm::function_ref<bool(int64_t)> isDynamic) {
31  for (OpFoldResult &ofr : values) {
32  if (ofr.is<Attribute>())
33  continue;
34  // Newly static, move from Value to constant.
35  if (auto cstOp =
36  ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>())
37  ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value());
38  }
39 }
40 
41 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
42  ArrayRef<int64_t> shape) {
43  llvm::SmallBitVector dimsToProject(shape.size());
44  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
45  if (shape[pos] == 1) {
46  dimsToProject.set(pos);
47  --rank;
48  }
49  }
50  return dimsToProject;
51 }
52 
54  OpFoldResult ofr) {
55  if (auto value = ofr.dyn_cast<Value>())
56  return value;
57  auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
58  assert(attr && "expect the op fold result casts to an integer attribute");
59  return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
60 }
61 
63  Type targetType, Value value) {
64  if (targetType == value.getType())
65  return value;
66 
67  bool targetIsIndex = targetType.isIndex();
68  bool valueIsIndex = value.getType().isIndex();
69  if (targetIsIndex ^ valueIsIndex)
70  return b.create<arith::IndexCastOp>(loc, targetType, value);
71 
72  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
73  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
74  assert(targetIntegerType && valueIntegerType &&
75  "unexpected cast between types other than integers and index");
76  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
77 
78  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
79  return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
80  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
81 }
82 
85  ArrayRef<OpFoldResult> valueOrAttrVec) {
86  return llvm::to_vector<4>(
87  llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
88  return getValueOrCreateConstantIndexOp(b, loc, value);
89  }));
90 }
91 
93  return b.create<arith::AndIOp>(loc, lhs, rhs);
94 }
96  if (lhs.getType().isa<IntegerType>())
97  return b.create<arith::AddIOp>(loc, lhs, rhs);
98  return b.create<arith::AddFOp>(loc, lhs, rhs);
99 }
101  if (lhs.getType().isa<IntegerType>())
102  return b.create<arith::MulIOp>(loc, lhs, rhs);
103  return b.create<arith::MulFOp>(loc, lhs, rhs);
104 }
106  if (lhs.getType().isa<IndexType, IntegerType>())
107  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
108  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
109 }
111  if (lhs.getType().isa<IndexType, IntegerType>())
112  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
113  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
114 }
116  return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
117 }
Include the generated interface declarations.
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:110
This class represents a single result from folding an operation.
Definition: OpDefinition.h:235
static constexpr const bool value
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:48
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
U dyn_cast() const
Definition: Types.h:270
Attributes are known-constant values of operations.
Definition: Attributes.h:24
void canonicalizeSubViewPart(SmallVectorImpl< OpFoldResult > &values, function_ref< bool(int64_t)> isDynamic)
Detects the values produced by a ConstantIndexOp and places the new constant in place of the corresponding sentinel value.
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:115
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:41
bool isIndex() const
Definition: Types.cpp:28
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:105
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:92
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:85
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:62
The matcher that matches a certain kind of op.
Definition: Matchers.h:170
Type getType() const
Return the type of this value.
Definition: Value.h:118
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:95
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:53
bool isa() const
Definition: Types.h:254
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:100
This class helps build Operations.
Definition: Builders.h:192
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:22