MLIR  20.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "llvm/ADT/SmallBitVector.h"
19 #include <numeric>
20 
21 using namespace mlir;
22 
23 std::optional<SmallVector<OpFoldResult>>
25  ShapedType expandedType,
26  ArrayRef<ReassociationIndices> reassociation,
27  ArrayRef<OpFoldResult> inputShape) {
28 
29  SmallVector<Value> outputShapeValues;
30  SmallVector<int64_t> outputShapeInts;
31  // For zero-rank inputs, all dims in result shape are unit extent.
32  if (inputShape.empty()) {
33  outputShapeInts.resize(expandedType.getRank(), 1);
34  return getMixedValues(outputShapeInts, outputShapeValues, b);
35  }
36 
37  // Check for all static shapes.
38  if (expandedType.hasStaticShape()) {
39  ArrayRef<int64_t> staticShape = expandedType.getShape();
40  outputShapeInts.assign(staticShape.begin(), staticShape.end());
41  return getMixedValues(outputShapeInts, outputShapeValues, b);
42  }
43 
44  outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
45  for (const auto &it : llvm::enumerate(reassociation)) {
46  ReassociationIndices indexGroup = it.value();
47 
48  int64_t indexGroupStaticSizesProductInt = 1;
49  bool foundDynamicShape = false;
50  for (int64_t index : indexGroup) {
51  int64_t outputDimSize = expandedType.getDimSize(index);
52  // Cannot infer expanded shape with multiple dynamic dims in the
53  // same reassociation group!
54  if (ShapedType::isDynamic(outputDimSize)) {
55  if (foundDynamicShape)
56  return std::nullopt;
57  foundDynamicShape = true;
58  } else {
59  outputShapeInts[index] = outputDimSize;
60  indexGroupStaticSizesProductInt *= outputDimSize;
61  }
62  }
63  if (!foundDynamicShape)
64  continue;
65 
66  int64_t inputIndex = it.index();
67  // Call get<Value>() under the assumption that we're not casting
68  // dynamism.
69  Value indexGroupSize = inputShape[inputIndex].get<Value>();
70  Value indexGroupStaticSizesProduct =
71  b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt);
72  Value dynamicDimSize = b.createOrFold<arith::DivUIOp>(
73  loc, indexGroupSize, indexGroupStaticSizesProduct);
74  outputShapeValues.push_back(dynamicDimSize);
75  }
76 
77  if ((int64_t)outputShapeValues.size() !=
78  llvm::count(outputShapeInts, ShapedType::kDynamic))
79  return std::nullopt;
80 
81  return getMixedValues(outputShapeInts, outputShapeValues, b);
82 }
83 
84 /// Matches a ConstantIndexOp.
85 /// TODO: This should probably just be a general matcher that uses matchConstant
86 /// and checks the operation for an index type.
89 }
90 
91 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
92  ArrayRef<int64_t> shape) {
93  llvm::SmallBitVector dimsToProject(shape.size());
94  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
95  if (shape[pos] == 1) {
96  dimsToProject.set(pos);
97  --rank;
98  }
99  }
100  return dimsToProject;
101 }
102 
104  OpFoldResult ofr) {
105  if (auto value = dyn_cast_if_present<Value>(ofr))
106  return value;
107  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
108  return b.create<arith::ConstantOp>(
109  loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
110 }
111 
113  OpFoldResult ofr) {
114  if (auto value = dyn_cast_if_present<Value>(ofr))
115  return value;
116  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
117  return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
118 }
119 
121  Type targetType, Value value) {
122  if (targetType == value.getType())
123  return value;
124 
125  bool targetIsIndex = targetType.isIndex();
126  bool valueIsIndex = value.getType().isIndex();
127  if (targetIsIndex ^ valueIsIndex)
128  return b.create<arith::IndexCastOp>(loc, targetType, value);
129 
130  auto targetIntegerType = dyn_cast<IntegerType>(targetType);
131  auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
132  assert(targetIntegerType && valueIntegerType &&
133  "unexpected cast between types other than integers and index");
134  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
135 
136  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
137  return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
138  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
139 }
140 
142  IntegerType toType, bool isUnsigned) {
143  // If operand is floating point, cast directly to the int type.
144  if (isa<FloatType>(operand.getType())) {
145  if (isUnsigned)
146  return b.create<arith::FPToUIOp>(toType, operand);
147  return b.create<arith::FPToSIOp>(toType, operand);
148  }
149  // Cast index operands directly to the int type.
150  if (operand.getType().isIndex())
151  return b.create<arith::IndexCastOp>(toType, operand);
152  if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
153  // Either extend or truncate.
154  if (toType.getWidth() > fromIntType.getWidth()) {
155  if (isUnsigned)
156  return b.create<arith::ExtUIOp>(toType, operand);
157  return b.create<arith::ExtSIOp>(toType, operand);
158  }
159  if (toType.getWidth() < fromIntType.getWidth())
160  return b.create<arith::TruncIOp>(toType, operand);
161  return operand;
162  }
163 
164  return {};
165 }
166 
168  FloatType toType, bool isUnsigned) {
169  // If operand is integer, cast directly to the float type.
170  // Note that it is unclear how to cast from BF16<->FP16.
171  if (isa<IntegerType>(operand.getType())) {
172  if (isUnsigned)
173  return b.create<arith::UIToFPOp>(toType, operand);
174  return b.create<arith::SIToFPOp>(toType, operand);
175  }
176  if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
177  if (toType.getWidth() > fromFpTy.getWidth())
178  return b.create<arith::ExtFOp>(toType, operand);
179  if (toType.getWidth() < fromFpTy.getWidth())
180  return b.create<arith::TruncFOp>(toType, operand);
181  return operand;
182  }
183 
184  return {};
185 }
186 
188  ComplexType targetType,
189  bool isUnsigned) {
190  if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
191  if (isa<FloatType>(targetType.getElementType()) &&
192  isa<FloatType>(fromComplexType.getElementType())) {
193  Value real = b.create<complex::ReOp>(operand);
194  Value imag = b.create<complex::ImOp>(operand);
195  Type targetETy = targetType.getElementType();
196  if (targetType.getElementType().getIntOrFloatBitWidth() <
197  fromComplexType.getElementType().getIntOrFloatBitWidth()) {
198  real = b.create<arith::TruncFOp>(targetETy, real);
199  imag = b.create<arith::TruncFOp>(targetETy, imag);
200  } else {
201  real = b.create<arith::ExtFOp>(targetETy, real);
202  imag = b.create<arith::ExtFOp>(targetETy, imag);
203  }
204  return b.create<complex::CreateOp>(targetType, real, imag);
205  }
206  }
207 
208  if (dyn_cast<FloatType>(operand.getType())) {
209  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
210  auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
211  Value from = operand;
212  if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
213  from = b.create<arith::ExtFOp>(toFpTy, from);
214  }
215  if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
216  from = b.create<arith::TruncFOp>(toFpTy, from);
217  }
219  mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
220  return b.create<complex::CreateOp>(targetType, from, zero);
221  }
222 
223  if (dyn_cast<IntegerType>(operand.getType())) {
224  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
225  Value from = operand;
226  if (isUnsigned) {
227  from = b.create<arith::UIToFPOp>(toFpTy, from);
228  } else {
229  from = b.create<arith::SIToFPOp>(toFpTy, from);
230  }
232  mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
233  return b.create<complex::CreateOp>(targetType, from, zero);
234  }
235 
236  return {};
237 }
238 
240  Type toType, bool isUnsignedCast) {
241  if (operand.getType() == toType)
242  return operand;
243  ImplicitLocOpBuilder ib(loc, b);
244  Value result;
245  if (auto intTy = dyn_cast<IntegerType>(toType)) {
246  result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
247  } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
248  result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
249  } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
250  result =
251  convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
252  }
253 
254  if (result)
255  return result;
256 
257  emitWarning(loc) << "could not cast operand of type " << operand.getType()
258  << " to " << toType;
259  return operand;
260 }
261 
264  ArrayRef<OpFoldResult> valueOrAttrVec) {
265  return llvm::to_vector<4>(
266  llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
267  return getValueOrCreateConstantIndexOp(b, loc, value);
268  }));
269 }
270 
272  Type type, const APInt &value) {
273  TypedAttr attr;
274  if (isa<IntegerType>(type)) {
275  attr = builder.getIntegerAttr(type, value);
276  } else {
277  auto vecTy = cast<ShapedType>(type);
278  attr = SplatElementsAttr::get(vecTy, value);
279  }
280 
281  return builder.create<arith::ConstantOp>(loc, attr);
282 }
283 
285  Type type, int64_t value) {
286  unsigned elementBitWidth = 0;
287  if (auto intTy = dyn_cast<IntegerType>(type))
288  elementBitWidth = intTy.getWidth();
289  else
290  elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
291 
292  return createScalarOrSplatConstant(builder, loc, type,
293  APInt(elementBitWidth, value));
294 }
295 
297  Type type, const APFloat &value) {
298  if (isa<FloatType>(type))
299  return builder.createOrFold<arith::ConstantOp>(
300  loc, type, builder.getFloatAttr(type, value));
301  TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
302  return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
303 }
304 
306  if (auto value = dyn_cast_if_present<Value>(ofr))
307  return value.getType();
308  auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
309  return attr.getType();
310 }
311 
313  return b.create<arith::AndIOp>(loc, lhs, rhs);
314 }
316  if (isa<FloatType>(lhs.getType()))
317  return b.create<arith::AddFOp>(loc, lhs, rhs);
318  return b.create<arith::AddIOp>(loc, lhs, rhs);
319 }
321  if (isa<FloatType>(lhs.getType()))
322  return b.create<arith::SubFOp>(loc, lhs, rhs);
323  return b.create<arith::SubIOp>(loc, lhs, rhs);
324 }
326  if (isa<FloatType>(lhs.getType()))
327  return b.create<arith::MulFOp>(loc, lhs, rhs);
328  return b.create<arith::MulIOp>(loc, lhs, rhs);
329 }
331  if (isa<FloatType>(lhs.getType()))
332  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
333  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
334 }
336  if (isa<FloatType>(lhs.getType()))
337  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
338  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
339 }
341  return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
342 }
343 
344 namespace mlir::arith {
345 
347  return createProduct(builder, loc, values, values.front().getType());
348 }
349 
351  Type resultType) {
352  Value one = builder.create<ConstantOp>(loc, resultType,
353  builder.getOneAttr(resultType));
354  ArithBuilder arithBuilder(builder, loc);
355  return std::accumulate(
356  values.begin(), values.end(), one,
357  [&arithBuilder](Value acc, Value v) { return arithBuilder.mul(acc, v); });
358 }
359 
360 } // namespace mlir::arith
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
Definition: Utils.cpp:187
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
Definition: Utils.cpp:141
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
Definition: Utils.cpp:167
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:242
FloatAttr getFloatAttr(Type type, double value)
Definition: Builders.cpp:265
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:353
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getWidth()
Return the bitwidth of this float type.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:210
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:523
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:57
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns a floating point value.
Definition: Arith.h:75
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition: Utils.cpp:346
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition: Utils.cpp:271
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:103
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:120
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition: Utils.cpp:24
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:91
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:87
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition: Utils.h:103
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:325
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:312
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:335
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:340
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:315
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:330
Value sub(Value lhs, Value rhs)
Definition: Utils.cpp:320
The matcher that matches a certain kind of op.
Definition: Matchers.h:224