MLIR  22.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Arith dialect -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Arith dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "llvm/ADT/SmallBitVector.h"
18 #include <numeric>
19 
20 using namespace mlir;
21 
22 std::optional<SmallVector<OpFoldResult>>
24  ShapedType expandedType,
25  ArrayRef<ReassociationIndices> reassociation,
26  ArrayRef<OpFoldResult> inputShape) {
27 
28  SmallVector<Value> outputShapeValues;
29  SmallVector<int64_t> outputShapeInts;
30  // For zero-rank inputs, all dims in result shape are unit extent.
31  if (inputShape.empty()) {
32  outputShapeInts.resize(expandedType.getRank(), 1);
33  return getMixedValues(outputShapeInts, outputShapeValues, b);
34  }
35 
36  // Check for all static shapes.
37  if (expandedType.hasStaticShape()) {
38  ArrayRef<int64_t> staticShape = expandedType.getShape();
39  outputShapeInts.assign(staticShape.begin(), staticShape.end());
40  return getMixedValues(outputShapeInts, outputShapeValues, b);
41  }
42 
43  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
44  for (const auto &it : llvm::enumerate(reassociation)) {
45  ReassociationIndices indexGroup = it.value();
46 
47  int64_t indexGroupStaticSizesProductInt = 1;
48  bool foundDynamicShape = false;
49  for (int64_t index : indexGroup) {
50  int64_t outputDimSize = expandedType.getDimSize(index);
51  // Cannot infer expanded shape with multiple dynamic dims in the
52  // same reassociation group!
53  if (ShapedType::isDynamic(outputDimSize)) {
54  if (foundDynamicShape)
55  return std::nullopt;
56  foundDynamicShape = true;
57  } else {
58  outputShapeInts[index] = outputDimSize;
59  indexGroupStaticSizesProductInt *= outputDimSize;
60  }
61  }
62  if (!foundDynamicShape)
63  continue;
64 
65  int64_t inputIndex = it.index();
66  // Call get<Value>() under the assumption that we're not casting
67  // dynamism.
68  Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
69  Value indexGroupStaticSizesProduct =
70  arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt);
71  Value dynamicDimSize = b.createOrFold<arith::DivSIOp>(
72  loc, indexGroupSize, indexGroupStaticSizesProduct);
73  outputShapeValues.push_back(dynamicDimSize);
74  }
75 
76  if ((int64_t)outputShapeValues.size() !=
77  llvm::count(outputShapeInts, ShapedType::kDynamic))
78  return std::nullopt;
79 
80  return getMixedValues(outputShapeInts, outputShapeValues, b);
81 }
82 
83 /// Matches a ConstantIndexOp.
/// The returned detail::op_matcher<arith::ConstantIndexOp> matches by op kind
/// only; it does not look at the constant's value.
84 /// TODO: This should probably just be a general matcher that uses matchConstant
85 /// and checks the operation for an index type.
88 }
89 
90 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
91  ArrayRef<int64_t> shape) {
92  llvm::SmallBitVector dimsToProject(shape.size());
93  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
94  if (shape[pos] == 1) {
95  dimsToProject.set(pos);
96  --rank;
97  }
98  }
99  return dimsToProject;
100 }
101 
103  OpFoldResult ofr) {
104  if (auto value = dyn_cast_if_present<Value>(ofr))
105  return value;
106  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
107  return arith::ConstantOp::create(
108  b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
109 }
110 
112  OpFoldResult ofr) {
113  if (auto value = dyn_cast_if_present<Value>(ofr))
114  return value;
115  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
116  return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue());
117 }
118 
120  Type targetType, Value value) {
121  if (targetType == value.getType())
122  return value;
123 
124  bool targetIsIndex = targetType.isIndex();
125  bool valueIsIndex = value.getType().isIndex();
126  if (targetIsIndex ^ valueIsIndex)
127  return arith::IndexCastOp::create(b, loc, targetType, value);
128 
129  auto targetIntegerType = dyn_cast<IntegerType>(targetType);
130  auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
131  assert(targetIntegerType && valueIntegerType &&
132  "unexpected cast between types other than integers and index");
133  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
134 
135  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
136  return arith::ExtSIOp::create(b, loc, targetIntegerType, value);
137  return arith::TruncIOp::create(b, loc, targetIntegerType, value);
138 }
139 
141  IntegerType toType, bool isUnsigned) {
142  // If operand is floating point, cast directly to the int type.
143  if (isa<FloatType>(operand.getType())) {
144  if (isUnsigned)
145  return arith::FPToUIOp::create(b, toType, operand);
146  return arith::FPToSIOp::create(b, toType, operand);
147  }
148  // Cast index operands directly to the int type.
149  if (operand.getType().isIndex())
150  return arith::IndexCastOp::create(b, toType, operand);
151  if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
152  // Either extend or truncate.
153  if (toType.getWidth() > fromIntType.getWidth()) {
154  if (isUnsigned)
155  return arith::ExtUIOp::create(b, toType, operand);
156  return arith::ExtSIOp::create(b, toType, operand);
157  }
158  if (toType.getWidth() < fromIntType.getWidth())
159  return arith::TruncIOp::create(b, toType, operand);
160  return operand;
161  }
162 
163  return {};
164 }
165 
167  FloatType toType, bool isUnsigned) {
168  // If operand is integer, cast directly to the float type.
169  // Note that it is unclear how to cast from BF16<->FP16.
170  if (isa<IntegerType>(operand.getType())) {
171  if (isUnsigned)
172  return arith::UIToFPOp::create(b, toType, operand);
173  return arith::SIToFPOp::create(b, toType, operand);
174  }
175  if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
176  if (toType.getWidth() > fromFpTy.getWidth())
177  return arith::ExtFOp::create(b, toType, operand);
178  if (toType.getWidth() < fromFpTy.getWidth())
179  return arith::TruncFOp::create(b, toType, operand);
180  return operand;
181  }
182 
183  return {};
184 }
185 
187  ComplexType targetType,
188  bool isUnsigned) {
189  if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
190  if (isa<FloatType>(targetType.getElementType()) &&
191  isa<FloatType>(fromComplexType.getElementType())) {
192  Value real = complex::ReOp::create(b, operand);
193  Value imag = complex::ImOp::create(b, operand);
194  Type targetETy = targetType.getElementType();
195  if (targetType.getElementType().getIntOrFloatBitWidth() <
196  fromComplexType.getElementType().getIntOrFloatBitWidth()) {
197  real = arith::TruncFOp::create(b, targetETy, real);
198  imag = arith::TruncFOp::create(b, targetETy, imag);
199  } else {
200  real = arith::ExtFOp::create(b, targetETy, real);
201  imag = arith::ExtFOp::create(b, targetETy, imag);
202  }
203  return complex::CreateOp::create(b, targetType, real, imag);
204  }
205  }
206 
207  if (isa<FloatType>(operand.getType())) {
208  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
209  auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
210  Value from = operand;
211  if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
212  from = arith::ExtFOp::create(b, toFpTy, from);
213  }
214  if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
215  from = arith::TruncFOp::create(b, toFpTy, from);
216  }
218  b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
219  return complex::CreateOp::create(b, targetType, from, zero);
220  }
221 
222  if (isa<IntegerType>(operand.getType())) {
223  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
224  Value from = operand;
225  if (isUnsigned) {
226  from = arith::UIToFPOp::create(b, toFpTy, from);
227  } else {
228  from = arith::SIToFPOp::create(b, toFpTy, from);
229  }
231  b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
232  return complex::CreateOp::create(b, targetType, from, zero);
233  }
234 
235  return {};
236 }
237 
239  Type toType, bool isUnsignedCast) {
240  if (operand.getType() == toType)
241  return operand;
242  ImplicitLocOpBuilder ib(loc, b);
243  Value result;
244  if (auto intTy = dyn_cast<IntegerType>(toType)) {
245  result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
246  } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
247  result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
248  } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
249  result =
250  convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
251  }
252 
253  if (result)
254  return result;
255 
256  emitWarning(loc) << "could not cast operand of type " << operand.getType()
257  << " to " << toType;
258  return operand;
259 }
260 
263  ArrayRef<OpFoldResult> valueOrAttrVec) {
  // Converts each entry of `valueOrAttrVec`, in order, with the scalar
  // getValueOrCreateConstantIndexOp: Values are returned unchanged and
  // IntegerAttrs become arith.constant index ops.
  // NOTE(review): this function's leading signature lines are not visible in
  // this listing; presumably it is the ArrayRef<OpFoldResult> overload of
  // getValueOrCreateConstantIndexOp returning SmallVector<Value> — confirm.
264  return llvm::to_vector<4>(
265  llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
266  return getValueOrCreateConstantIndexOp(b, loc, value);
267  }));
268 }
269 
271  Type type, const APInt &value) {
272  TypedAttr attr;
273  if (isa<IntegerType>(type)) {
274  attr = builder.getIntegerAttr(type, value);
275  } else {
276  auto vecTy = cast<ShapedType>(type);
277  attr = SplatElementsAttr::get(vecTy, value);
278  }
279 
280  return arith::ConstantOp::create(builder, loc, attr);
281 }
282 
284  Type type, int64_t value) {
285  unsigned elementBitWidth = 0;
286  if (auto intTy = dyn_cast<IntegerType>(type))
287  elementBitWidth = intTy.getWidth();
288  else
289  elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
290 
291  return createScalarOrSplatConstant(builder, loc, type,
292  APInt(elementBitWidth, value));
293 }
294 
296  Type type, const APFloat &value) {
297  if (isa<FloatType>(type))
298  return builder.createOrFold<arith::ConstantOp>(
299  loc, type, builder.getFloatAttr(type, value));
300  TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
301  return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
302 }
303 
305  if (auto value = dyn_cast_if_present<Value>(ofr))
306  return value.getType();
307  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
308  return attr.getType();
309 }
310 
312  return arith::AndIOp::create(b, loc, lhs, rhs);
313 }
315  if (isa<FloatType>(lhs.getType()))
316  return arith::AddFOp::create(b, loc, lhs, rhs);
317  return arith::AddIOp::create(b, loc, lhs, rhs, ovf);
318 }
320  if (isa<FloatType>(lhs.getType()))
321  return arith::SubFOp::create(b, loc, lhs, rhs);
322  return arith::SubIOp::create(b, loc, lhs, rhs, ovf);
323 }
325  if (isa<FloatType>(lhs.getType()))
326  return arith::MulFOp::create(b, loc, lhs, rhs);
327  return arith::MulIOp::create(b, loc, lhs, rhs, ovf);
328 }
330  if (isa<FloatType>(lhs.getType()))
331  return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs);
332  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs);
333 }
335  if (isa<FloatType>(lhs.getType()))
336  return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs);
337  return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs);
338 }
340  return arith::SelectOp::create(b, loc, cmp, lhs, rhs);
341 }
342 
343 namespace mlir::arith {
344 
346  return createProduct(builder, loc, values, values.front().getType());
347 }
348 
350  Type resultType) {
351  Value one = ConstantOp::create(builder, loc, resultType,
352  builder.getOneAttr(resultType));
353  ArithBuilder arithBuilder(builder, loc);
354  return std::accumulate(
355  values.begin(), values.end(), one,
356  [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
357 }
358 
359 /// Map strings to float types.
360 std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
361  Builder b(ctx);
363  .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
364  .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
365  .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
366  .Case("f8E5M2", b.getType<Float8E5M2Type>())
367  .Case("f8E4M3", b.getType<Float8E4M3Type>())
368  .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
369  .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
370  .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
371  .Case("f8E3M4", b.getType<Float8E3M4Type>())
372  .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
373  .Case("bf16", b.getType<BFloat16Type>())
374  .Case("f16", b.getType<Float16Type>())
375  .Case("f32", b.getType<Float32Type>())
376  .Case("f64", b.getType<Float64Type>())
377  .Case("f80", b.getType<Float80Type>())
378  .Case("f128", b.getType<Float128Type>())
379  .Default(std::nullopt);
380 }
381 
382 } // namespace mlir::arith
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
Definition: Utils.cpp:186
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
Definition: Utils.cpp:140
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
Definition: Utils.cpp:166
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:253
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:341
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
Definition: Builders.h:623
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of a specific op type at the current insertion point, and immediately try to fold it.
Definition: Builders.h:519
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition: ArithOps.cpp:330
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition: Utils.cpp:360
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition: Utils.cpp:345
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:238
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
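Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type must match the element type of type). If type is a shaped type, a splat constant of the given value is created.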
Definition: Utils.cpp:270
Type getType(OpFoldResult ofr)
Returns the type of ofr: the type of the Value it holds, or the type of its IntegerAttr.
Definition: Utils.cpp:304
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:102
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:119
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:23
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:90
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, in which every element for which ShapedType::isDynamic is true is replaced, in order, by the next value from dynamicValues.
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:86
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition: Utils.h:103
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:324
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:311
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:334
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:339
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:314
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:329
Value sub(Value lhs, Value rhs)
Definition: Utils.cpp:319
The matcher that matches a certain kind of op.
Definition: Matchers.h:283