MLIR  20.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "llvm/ADT/SmallBitVector.h"
19 #include <numeric>
20 
21 using namespace mlir;
22 
23 std::optional<SmallVector<OpFoldResult>>
25  ShapedType expandedType,
26  ArrayRef<ReassociationIndices> reassociation,
27  ArrayRef<OpFoldResult> inputShape) {
28 
29  SmallVector<Value> outputShapeValues;
30  SmallVector<int64_t> outputShapeInts;
31  // For zero-rank inputs, all dims in result shape are unit extent.
32  if (inputShape.empty()) {
33  outputShapeInts.resize(expandedType.getRank(), 1);
34  return getMixedValues(outputShapeInts, outputShapeValues, b);
35  }
36 
37  // Check for all static shapes.
38  if (expandedType.hasStaticShape()) {
39  ArrayRef<int64_t> staticShape = expandedType.getShape();
40  outputShapeInts.assign(staticShape.begin(), staticShape.end());
41  return getMixedValues(outputShapeInts, outputShapeValues, b);
42  }
43 
44  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
45  for (const auto &it : llvm::enumerate(reassociation)) {
46  ReassociationIndices indexGroup = it.value();
47 
48  int64_t indexGroupStaticSizesProductInt = 1;
49  bool foundDynamicShape = false;
50  for (int64_t index : indexGroup) {
51  int64_t outputDimSize = expandedType.getDimSize(index);
52  // Cannot infer expanded shape with multiple dynamic dims in the
53  // same reassociation group!
54  if (ShapedType::isDynamic(outputDimSize)) {
55  if (foundDynamicShape)
56  return std::nullopt;
57  foundDynamicShape = true;
58  } else {
59  outputShapeInts[index] = outputDimSize;
60  indexGroupStaticSizesProductInt *= outputDimSize;
61  }
62  }
63  if (!foundDynamicShape)
64  continue;
65 
66  int64_t inputIndex = it.index();
67  // Call get<Value>() under the assumption that we're not casting
68  // dynamism.
69  Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
70  Value indexGroupStaticSizesProduct =
71  b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
72  Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
73  loc, indexGroupSize, indexGroupStaticSizesProduct);
74  outputShapeValues.push_back(dynamicDimSize);
75  }
76 
77  if ((int64_t)outputShapeValues.size() !=
78  llvm::count(outputShapeInts, ShapedType::kDynamic))
79  return std::nullopt;
80 
81  return getMixedValues(outputShapeInts, outputShapeValues, b);
82 }
83 
84 /// Matches a ConstantIndexOp.
85 /// TODO: This should probably just be a general matcher that uses matchConstant
86 /// and checks the operation for an index type.
89 }
90 
91 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
92  ArrayRef<int64_t> shape) {
93  llvm::SmallBitVector dimsToProject(shape.size());
94  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
95  if (shape[pos] == 1) {
96  dimsToProject.set(pos);
97  --rank;
98  }
99  }
100  return dimsToProject;
101 }
102 
104  OpFoldResult ofr) {
105  if (auto value = dyn_cast_if_present<Value>(ofr))
106  return value;
107  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
108  return b.create<arith::ConstantOp>(
109  loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
110 }
111 
113  OpFoldResult ofr) {
114  if (auto value = dyn_cast_if_present<Value>(ofr))
115  return value;
116  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
117  return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
118 }
119 
121  Type targetType, Value value) {
122  if (targetType == value.getType())
123  return value;
124 
125  bool targetIsIndex = targetType.isIndex();
126  bool valueIsIndex = value.getType().isIndex();
127  if (targetIsIndex ^ valueIsIndex)
128  return b.create<arith::IndexCastOp>(loc, targetType, value);
129 
130  auto targetIntegerType = dyn_cast<IntegerType>(targetType);
131  auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
132  assert(targetIntegerType && valueIntegerType &&
133  "unexpected cast between types other than integers and index");
134  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
135 
136  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
137  return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
138  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
139 }
140 
142  IntegerType toType, bool isUnsigned) {
143  // If operand is floating point, cast directly to the int type.
144  if (isa<FloatType>(operand.getType())) {
145  if (isUnsigned)
146  return b.create<arith::FPToUIOp>(toType, operand);
147  return b.create<arith::FPToSIOp>(toType, operand);
148  }
149  // Cast index operands directly to the int type.
150  if (operand.getType().isIndex())
151  return b.create<arith::IndexCastOp>(toType, operand);
152  if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
153  // Either extend or truncate.
154  if (toType.getWidth() > fromIntType.getWidth()) {
155  if (isUnsigned)
156  return b.create<arith::ExtUIOp>(toType, operand);
157  return b.create<arith::ExtSIOp>(toType, operand);
158  }
159  if (toType.getWidth() < fromIntType.getWidth())
160  return b.create<arith::TruncIOp>(toType, operand);
161  return operand;
162  }
163 
164  return {};
165 }
166 
168  FloatType toType, bool isUnsigned) {
169  // If operand is integer, cast directly to the float type.
170  // Note that it is unclear how to cast from BF16<->FP16.
171  if (isa<IntegerType>(operand.getType())) {
172  if (isUnsigned)
173  return b.create<arith::UIToFPOp>(toType, operand);
174  return b.create<arith::SIToFPOp>(toType, operand);
175  }
176  if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
177  if (toType.getWidth() > fromFpTy.getWidth())
178  return b.create<arith::ExtFOp>(toType, operand);
179  if (toType.getWidth() < fromFpTy.getWidth())
180  return b.create<arith::TruncFOp>(toType, operand);
181  return operand;
182  }
183 
184  return {};
185 }
186 
188  ComplexType targetType,
189  bool isUnsigned) {
190  if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
191  if (isa<FloatType>(targetType.getElementType()) &&
192  isa<FloatType>(fromComplexType.getElementType())) {
193  Value real = b.create<complex::ReOp>(operand);
194  Value imag = b.create<complex::ImOp>(operand);
195  Type targetETy = targetType.getElementType();
196  if (targetType.getElementType().getIntOrFloatBitWidth() <
197  fromComplexType.getElementType().getIntOrFloatBitWidth()) {
198  real = b.create<arith::TruncFOp>(targetETy, real);
199  imag = b.create<arith::TruncFOp>(targetETy, imag);
200  } else {
201  real = b.create<arith::ExtFOp>(targetETy, real);
202  imag = b.create<arith::ExtFOp>(targetETy, imag);
203  }
204  return b.create<complex::CreateOp>(targetType, real, imag);
205  }
206  }
207 
208  if (dyn_cast<FloatType>(operand.getType())) {
209  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
210  auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
211  Value from = operand;
212  if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
213  from = b.create<arith::ExtFOp>(toFpTy, from);
214  }
215  if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
216  from = b.create<arith::TruncFOp>(toFpTy, from);
217  }
219  mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
220  return b.create<complex::CreateOp>(targetType, from, zero);
221  }
222 
223  if (dyn_cast<IntegerType>(operand.getType())) {
224  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
225  Value from = operand;
226  if (isUnsigned) {
227  from = b.create<arith::UIToFPOp>(toFpTy, from);
228  } else {
229  from = b.create<arith::SIToFPOp>(toFpTy, from);
230  }
232  mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
233  return b.create<complex::CreateOp>(targetType, from, zero);
234  }
235 
236  return {};
237 }
238 
240  Type toType, bool isUnsignedCast) {
241  if (operand.getType() == toType)
242  return operand;
243  ImplicitLocOpBuilder ib(loc, b);
244  Value result;
245  if (auto intTy = dyn_cast<IntegerType>(toType)) {
246  result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
247  } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
248  result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
249  } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
250  result =
251  convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
252  }
253 
254  if (result)
255  return result;
256 
257  emitWarning(loc) << "could not cast operand of type " << operand.getType()
258  << " to " << toType;
259  return operand;
260 }
261 
264  ArrayRef<OpFoldResult> valueOrAttrVec) {
265  return llvm::to_vector<4>(
266  llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
267  return getValueOrCreateConstantIndexOp(b, loc, value);
268  }));
269 }
270 
272  Type type, const APInt &value) {
273  TypedAttr attr;
274  if (isa<IntegerType>(type)) {
275  attr = builder.getIntegerAttr(type, value);
276  } else {
277  auto vecTy = cast<ShapedType>(type);
278  attr = SplatElementsAttr::get(vecTy, value);
279  }
280 
281  return builder.create<arith::ConstantOp>(loc, attr);
282 }
283 
285  Type type, int64_t value) {
286  unsigned elementBitWidth = 0;
287  if (auto intTy = dyn_cast<IntegerType>(type))
288  elementBitWidth = intTy.getWidth();
289  else
290  elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
291 
292  return createScalarOrSplatConstant(builder, loc, type,
293  APInt(elementBitWidth, value));
294 }
295 
297  Type type, const APFloat &value) {
298  if (isa<FloatType>(type))
299  return builder.createOrFold<arith::ConstantOp>(
300  loc, type, builder.getFloatAttr(type, value));
301  TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
302  return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
303 }
304 
306  if (auto value = dyn_cast_if_present<Value>(ofr))
307  return value.getType();
308  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
309  return attr.getType();
310 }
311 
313  return b.create<arith::AndIOp>(loc, lhs, rhs);
314 }
316  if (isa<FloatType>(lhs.getType()))
317  return b.create<arith::AddFOp>(loc, lhs, rhs);
318  return b.create<arith::AddIOp>(loc, lhs, rhs);
319 }
321  if (isa<FloatType>(lhs.getType()))
322  return b.create<arith::SubFOp>(loc, lhs, rhs);
323  return b.create<arith::SubIOp>(loc, lhs, rhs);
324 }
326  if (isa<FloatType>(lhs.getType()))
327  return b.create<arith::MulFOp>(loc, lhs, rhs);
328  return b.create<arith::MulIOp>(loc, lhs, rhs);
329 }
331  if (isa<FloatType>(lhs.getType()))
332  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
333  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
334 }
336  if (isa<FloatType>(lhs.getType()))
337  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
338  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
339 }
341  return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
342 }
343 
344 namespace mlir::arith {
345 
347  return createProduct(builder, loc, values, values.front().getType());
348 }
349 
351  Type resultType) {
352  Value one = builder.create<ConstantOp>(loc, resultType,
353  builder.getOneAttr(resultType));
354  ArithBuilder arithBuilder(builder, loc);
355  return std::accumulate(
356  values.begin(), values.end(), one,
357  [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
358 }
359 
360 /// Map strings to float types.
361 std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
362  Builder b(ctx);
364  .Case("f4E2M1FN", b.getFloat4E2M1FNType())
365  .Case("f6E2M3FN", b.getFloat6E2M3FNType())
366  .Case("f6E3M2FN", b.getFloat6E3M2FNType())
367  .Case("f8E5M2", b.getFloat8E5M2Type())
368  .Case("f8E4M3", b.getFloat8E4M3Type())
369  .Case("f8E4M3FN", b.getFloat8E4M3FNType())
370  .Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
371  .Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
372  .Case("f8E3M4", b.getFloat8E3M4Type())
373  .Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
374  .Case("bf16", b.getBF16Type())
375  .Case("f16", b.getF16Type())
376  .Case("f32", b.getF32Type())
377  .Case("f64", b.getF64Type())
378  .Case("f80", b.getF80Type())
379  .Case("f128", b.getF128Type())
380  .Default(std::nullopt);
381 }
382 
383 } // namespace mlir::arith
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
Definition: Utils.cpp:187
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
Definition: Utils.cpp:141
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
Definition: Utils.cpp:167
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
FloatType getFloat8E5M2Type()
Definition: Builders.cpp:49
FloatType getF80Type()
Definition: Builders.cpp:91
FloatType getF128Type()
Definition: Builders.cpp:93
FloatType getFloat8E8M0FNUType()
Definition: Builders.cpp:77
FloatType getF32Type()
Definition: Builders.cpp:87
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
FloatType getFloat6E3M2FNType()
Definition: Builders.cpp:45
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:294
FloatType getFloat8E3M4Type()
Definition: Builders.cpp:73
FloatType getFloat8E4M3Type()
Definition: Builders.cpp:53
FloatType getF16Type()
Definition: Builders.cpp:83
FloatType getBF16Type()
Definition: Builders.cpp:81
FloatType getFloat4E2M1FNType()
Definition: Builders.cpp:37
FloatType getFloat8E4M3FNType()
Definition: Builders.cpp:57
FloatType getFloat6E2M3FNType()
Definition: Builders.cpp:41
FloatType getFloat8E4M3FNUZType()
Definition: Builders.cpp:65
FloatType getFloat8E5M2FNUZType()
Definition: Builders.cpp:61
FloatType getF64Type()
Definition: Builders.cpp:89
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:382
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getWidth()
Return the bitwidth of this float type.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:529
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns a floating point value.
Definition: Arith.h:76
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition: Utils.cpp:361
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition: Utils.cpp:346
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:271
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:120
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:91
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:87
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition: Utils.h:103
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:325
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:312
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:335
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:340
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:315
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:330
Value sub(Value lhs, Value rhs)
Definition: Utils.cpp:320
The matcher that matches a certain kind of op.
Definition: Matchers.h:283