MLIR  17.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "llvm/ADT/SmallBitVector.h"
16 
17 using namespace mlir;
18 
19 /// Matches a ConstantIndexOp.
20 /// TODO: This should probably just be a general matcher that uses matchConstant
21 /// and checks the operation for an index type.
24 }
25 
26 /// Detects the `values` produced by a ConstantIndexOp and replaces each such
27 /// Value, in place, with the equivalent constant index attribute.
30  llvm::function_ref<bool(int64_t)> isDynamic) {
31  for (OpFoldResult &ofr : values) {
32  if (ofr.is<Attribute>())
33  continue;
34  // Newly static, move from Value to constant.
35  if (auto cstOp =
36  ofr.dyn_cast<Value>().getDefiningOp<arith::ConstantIndexOp>())
37  ofr = OpBuilder(cstOp).getIndexAttr(cstOp.value());
38  }
39 }
40 
41 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
42  ArrayRef<int64_t> shape) {
43  llvm::SmallBitVector dimsToProject(shape.size());
44  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
45  if (shape[pos] == 1) {
46  dimsToProject.set(pos);
47  --rank;
48  }
49  }
50  return dimsToProject;
51 }
52 
54  OpFoldResult ofr) {
55  if (auto value = ofr.dyn_cast<Value>())
56  return value;
57  auto attr = ofr.dyn_cast<Attribute>().dyn_cast<IntegerAttr>();
58  assert(attr && "expect the op fold result casts to an integer attribute");
59  return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
60 }
61 
63  Type targetType, Value value) {
64  if (targetType == value.getType())
65  return value;
66 
67  bool targetIsIndex = targetType.isIndex();
68  bool valueIsIndex = value.getType().isIndex();
69  if (targetIsIndex ^ valueIsIndex)
70  return b.create<arith::IndexCastOp>(loc, targetType, value);
71 
72  auto targetIntegerType = targetType.dyn_cast<IntegerType>();
73  auto valueIntegerType = value.getType().dyn_cast<IntegerType>();
74  assert(targetIntegerType && valueIntegerType &&
75  "unexpected cast between types other than integers and index");
76  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
77 
78  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
79  return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
80  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
81 }
82 
84  Type toType, bool isUnsignedCast) {
85  if (operand.getType() == toType)
86  return operand;
87  if (auto toIntType = toType.dyn_cast<IntegerType>()) {
88  // If operand is floating point, cast directly to the int type.
89  if (operand.getType().isa<FloatType>()) {
90  if (isUnsignedCast)
91  return b.create<arith::FPToUIOp>(loc, toType, operand);
92  return b.create<arith::FPToSIOp>(loc, toType, operand);
93  }
94  // Cast index operands directly to the int type.
95  if (operand.getType().isIndex())
96  return b.create<arith::IndexCastOp>(loc, toType, operand);
97  if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
98  // Either extend or truncate.
99  if (toIntType.getWidth() > fromIntType.getWidth()) {
100  if (isUnsignedCast)
101  return b.create<arith::ExtUIOp>(loc, toType, operand);
102  return b.create<arith::ExtSIOp>(loc, toType, operand);
103  }
104  if (toIntType.getWidth() < fromIntType.getWidth())
105  return b.create<arith::TruncIOp>(loc, toType, operand);
106  }
107  } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
108  // If operand is integer, cast directly to the float type.
109  // Note that it is unclear how to cast from BF16<->FP16.
110  if (operand.getType().isa<IntegerType>()) {
111  if (isUnsignedCast)
112  return b.create<arith::UIToFPOp>(loc, toFloatType, operand);
113  return b.create<arith::SIToFPOp>(loc, toFloatType, operand);
114  }
115  if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
116  if (toFloatType.getWidth() > fromFloatType.getWidth())
117  return b.create<arith::ExtFOp>(loc, toFloatType, operand);
118  if (toFloatType.getWidth() < fromFloatType.getWidth())
119  return b.create<arith::TruncFOp>(loc, toFloatType, operand);
120  }
121  }
122  emitWarning(loc) << "could not cast operand of type " << operand.getType()
123  << " to " << toType;
124  return operand;
125 }
126 
129  ArrayRef<OpFoldResult> valueOrAttrVec) {
130  return llvm::to_vector<4>(
131  llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
132  return getValueOrCreateConstantIndexOp(b, loc, value);
133  }));
134 }
135 
137  return b.create<arith::AndIOp>(loc, lhs, rhs);
138 }
140  if (lhs.getType().isa<FloatType>())
141  return b.create<arith::AddFOp>(loc, lhs, rhs);
142  return b.create<arith::AddIOp>(loc, lhs, rhs);
143 }
145  if (lhs.getType().isa<FloatType>())
146  return b.create<arith::SubFOp>(loc, lhs, rhs);
147  return b.create<arith::SubIOp>(loc, lhs, rhs);
148 }
150  if (lhs.getType().isa<FloatType>())
151  return b.create<arith::MulFOp>(loc, lhs, rhs);
152  return b.create<arith::MulIOp>(loc, lhs, rhs);
153 }
155  if (lhs.getType().isa<FloatType>())
156  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
157  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
158 }
160  if (lhs.getType().isa<FloatType>())
161  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
162  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
163 }
165  return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
166 }
Attributes are known-constant values of operations.
Definition: Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:109
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:199
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
This class represents a single result from folding an operation.
Definition: OpDefinition.h:233
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:46
U dyn_cast() const
Definition: Types.h:311
bool isa() const
Definition: Types.h:301
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:89
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:83
void canonicalizeSubViewPart(SmallVectorImpl< OpFoldResult > &values, function_ref< bool(int64_t)> isDynamic)
Detects the values produced by a ConstantIndexOp and places the new constant in place of the correspo...
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:62
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:53
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:41
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:22
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:149
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:136
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:159
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:164
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:139
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:154
Value sub(Value lhs, Value rhs)
Definition: Utils.cpp:144
The matcher that matches a certain kind of op.
Definition: Matchers.h:162