MLIR  19.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "llvm/ADT/SmallBitVector.h"
18 #include <numeric>
19 
20 using namespace mlir;
21 
22 /// Matches a ConstantIndexOp.
23 /// TODO: This should probably just be a general matcher that uses matchConstant
24 /// and checks the operation for an index type.
27 }
28 
29 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
30  ArrayRef<int64_t> shape) {
31  llvm::SmallBitVector dimsToProject(shape.size());
32  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
33  if (shape[pos] == 1) {
34  dimsToProject.set(pos);
35  --rank;
36  }
37  }
38  return dimsToProject;
39 }
40 
42  OpFoldResult ofr) {
43  if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
44  return value;
45  auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
46  assert(attr && "expect the op fold result casts to an integer attribute");
47  return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
48 }
49 
51  Type targetType, Value value) {
52  if (targetType == value.getType())
53  return value;
54 
55  bool targetIsIndex = targetType.isIndex();
56  bool valueIsIndex = value.getType().isIndex();
57  if (targetIsIndex ^ valueIsIndex)
58  return b.create<arith::IndexCastOp>(loc, targetType, value);
59 
60  auto targetIntegerType = dyn_cast<IntegerType>(targetType);
61  auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
62  assert(targetIntegerType && valueIntegerType &&
63  "unexpected cast between types other than integers and index");
64  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
65 
66  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
67  return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
68  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
69 }
70 
72  IntegerType toType, bool isUnsigned) {
73  // If operand is floating point, cast directly to the int type.
74  if (isa<FloatType>(operand.getType())) {
75  if (isUnsigned)
76  return b.create<arith::FPToUIOp>(toType, operand);
77  return b.create<arith::FPToSIOp>(toType, operand);
78  }
79  // Cast index operands directly to the int type.
80  if (operand.getType().isIndex())
81  return b.create<arith::IndexCastOp>(toType, operand);
82  if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
83  // Either extend or truncate.
84  if (toType.getWidth() > fromIntType.getWidth()) {
85  if (isUnsigned)
86  return b.create<arith::ExtUIOp>(toType, operand);
87  return b.create<arith::ExtSIOp>(toType, operand);
88  }
89  if (toType.getWidth() < fromIntType.getWidth())
90  return b.create<arith::TruncIOp>(toType, operand);
91  return operand;
92  }
93 
94  return {};
95 }
96 
98  FloatType toType, bool isUnsigned) {
99  // If operand is integer, cast directly to the float type.
100  // Note that it is unclear how to cast from BF16<->FP16.
101  if (isa<IntegerType>(operand.getType())) {
102  if (isUnsigned)
103  return b.create<arith::UIToFPOp>(toType, operand);
104  return b.create<arith::SIToFPOp>(toType, operand);
105  }
106  if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
107  if (toType.getWidth() > fromFpTy.getWidth())
108  return b.create<arith::ExtFOp>(toType, operand);
109  if (toType.getWidth() < fromFpTy.getWidth())
110  return b.create<arith::TruncFOp>(toType, operand);
111  return operand;
112  }
113 
114  return {};
115 }
116 
118  ComplexType targetType,
119  bool isUnsigned) {
120  if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
121  if (isa<FloatType>(targetType.getElementType()) &&
122  isa<FloatType>(fromComplexType.getElementType())) {
123  Value real = b.create<complex::ReOp>(operand);
124  Value imag = b.create<complex::ImOp>(operand);
125  Type targetETy = targetType.getElementType();
126  if (targetType.getElementType().getIntOrFloatBitWidth() <
127  fromComplexType.getElementType().getIntOrFloatBitWidth()) {
128  real = b.create<arith::TruncFOp>(targetETy, real);
129  imag = b.create<arith::TruncFOp>(targetETy, imag);
130  } else {
131  real = b.create<arith::ExtFOp>(targetETy, real);
132  imag = b.create<arith::ExtFOp>(targetETy, imag);
133  }
134  return b.create<complex::CreateOp>(targetType, real, imag);
135  }
136  }
137 
138  if (dyn_cast<FloatType>(operand.getType())) {
139  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
140  auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
141  Value from = operand;
142  if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
143  from = b.create<arith::ExtFOp>(toFpTy, from);
144  }
145  if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
146  from = b.create<arith::TruncFOp>(toFpTy, from);
147  }
149  mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
150  return b.create<complex::CreateOp>(targetType, from, zero);
151  }
152 
153  if (dyn_cast<IntegerType>(operand.getType())) {
154  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
155  Value from = operand;
156  if (isUnsigned) {
157  from = b.create<arith::UIToFPOp>(toFpTy, from);
158  } else {
159  from = b.create<arith::SIToFPOp>(toFpTy, from);
160  }
162  mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
163  return b.create<complex::CreateOp>(targetType, from, zero);
164  }
165 
166  return {};
167 }
168 
170  Type toType, bool isUnsignedCast) {
171  if (operand.getType() == toType)
172  return operand;
173  ImplicitLocOpBuilder ib(loc, b);
174  Value result;
175  if (auto intTy = dyn_cast<IntegerType>(toType)) {
176  result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
177  } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
178  result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
179  } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
180  result =
181  convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
182  }
183 
184  if (result)
185  return result;
186 
187  emitWarning(loc) << "could not cast operand of type " << operand.getType()
188  << " to " << toType;
189  return operand;
190 }
191 
194  ArrayRef<OpFoldResult> valueOrAttrVec) {
195  return llvm::to_vector<4>(
196  llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
197  return getValueOrCreateConstantIndexOp(b, loc, value);
198  }));
199 }
200 
202  Type type, const APInt &value) {
203  TypedAttr attr;
204  if (isa<IntegerType>(type)) {
205  attr = builder.getIntegerAttr(type, value);
206  } else {
207  auto vecTy = cast<ShapedType>(type);
208  attr = SplatElementsAttr::get(vecTy, value);
209  }
210 
211  return builder.create<arith::ConstantOp>(loc, attr);
212 }
213 
215  Type type, int64_t value) {
216  unsigned elementBitWidth = 0;
217  if (auto intTy = dyn_cast<IntegerType>(type))
218  elementBitWidth = intTy.getWidth();
219  else
220  elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
221 
222  return createScalarOrSplatConstant(builder, loc, type,
223  APInt(elementBitWidth, value));
224 }
225 
227  Type type, const APFloat &value) {
228  if (isa<FloatType>(type))
229  return builder.createOrFold<arith::ConstantOp>(
230  loc, type, builder.getFloatAttr(type, value));
231  TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
232  return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
233 }
234 
236  return b.create<arith::AndIOp>(loc, lhs, rhs);
237 }
239  if (isa<FloatType>(lhs.getType()))
240  return b.create<arith::AddFOp>(loc, lhs, rhs);
241  return b.create<arith::AddIOp>(loc, lhs, rhs);
242 }
244  if (isa<FloatType>(lhs.getType()))
245  return b.create<arith::SubFOp>(loc, lhs, rhs);
246  return b.create<arith::SubIOp>(loc, lhs, rhs);
247 }
249  if (isa<FloatType>(lhs.getType()))
250  return b.create<arith::MulFOp>(loc, lhs, rhs);
251  return b.create<arith::MulIOp>(loc, lhs, rhs);
252 }
254  if (isa<FloatType>(lhs.getType()))
255  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
256  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
257 }
259  if (isa<FloatType>(lhs.getType()))
260  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
261  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
262 }
264  return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
265 }
266 
267 namespace mlir::arith {
268 
270  return createProduct(builder, loc, values, values.front().getType());
271 }
272 
274  Type resultType) {
275  Value one = builder.create<ConstantOp>(loc, resultType,
276  builder.getOneAttr(resultType));
277  ArithBuilder arithBuilder(builder, loc);
278  return std::accumulate(
279  values.begin(), values.end(), one,
280  [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
281 }
282 
283 } // namespace mlir::arith
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
Definition: Utils.cpp:117
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
Definition: Utils.cpp:71
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
Definition: Utils.cpp:97
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:238
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:261
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:349
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getWidth()
Return the bitwidth of this float type.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:522
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns a floating point value.
Definition: Arith.h:75
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition: Utils.cpp:269
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:169
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:201
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:50
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:41
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:29
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:25
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition: Utils.h:70
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:248
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:235
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:258
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:263
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:238
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:253
Value sub(Value lhs, Value rhs)
Definition: Utils.cpp:243
The matcher that matches a certain kind of op.
Definition: Matchers.h:224