MLIR 22.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utilities to support the Arith dialect -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the Arith dialect.
10//
11//===----------------------------------------------------------------------===//
12
17#include "llvm/ADT/SmallBitVector.h"
18#include <numeric>
19
20using namespace mlir;
21
22std::optional<SmallVector<OpFoldResult>>
24 ShapedType expandedType,
26 ArrayRef<OpFoldResult> inputShape) {
27
28 SmallVector<Value> outputShapeValues;
29 SmallVector<int64_t> outputShapeInts;
30 // For zero-rank inputs, all dims in result shape are unit extent.
31 if (inputShape.empty()) {
32 outputShapeInts.resize(expandedType.getRank(), 1);
33 return getMixedValues(outputShapeInts, outputShapeValues, b);
34 }
35
36 // Check for all static shapes.
37 if (expandedType.hasStaticShape()) {
38 ArrayRef<int64_t> staticShape = expandedType.getShape();
39 outputShapeInts.assign(staticShape.begin(), staticShape.end());
40 return getMixedValues(outputShapeInts, outputShapeValues, b);
41 }
42
43 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
44 for (const auto &it : llvm::enumerate(reassociation)) {
45 ReassociationIndices indexGroup = it.value();
46
47 int64_t indexGroupStaticSizesProductInt = 1;
48 bool foundDynamicShape = false;
49 for (int64_t index : indexGroup) {
50 int64_t outputDimSize = expandedType.getDimSize(index);
51 // Cannot infer expanded shape with multiple dynamic dims in the
52 // same reassociation group!
53 if (ShapedType::isDynamic(outputDimSize)) {
54 if (foundDynamicShape)
55 return std::nullopt;
56 foundDynamicShape = true;
57 } else {
58 outputShapeInts[index] = outputDimSize;
59 indexGroupStaticSizesProductInt *= outputDimSize;
60 }
61 }
62 if (!foundDynamicShape)
63 continue;
64
65 int64_t inputIndex = it.index();
66 // Call get<Value>() under the assumption that we're not casting
67 // dynamism.
68 Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
69 Value indexGroupStaticSizesProduct =
70 arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt);
71 Value dynamicDimSize = b.createOrFold<arith::DivSIOp>(
72 loc, indexGroupSize, indexGroupStaticSizesProduct);
73 outputShapeValues.push_back(dynamicDimSize);
74 }
75
76 if ((int64_t)outputShapeValues.size() !=
77 llvm::count(outputShapeInts, ShapedType::kDynamic))
78 return std::nullopt;
79
80 return getMixedValues(outputShapeInts, outputShapeValues, b);
81}
82
83/// Matches a ConstantIndexOp.
84/// TODO: This should probably just be a general matcher that uses matchConstant
85/// and checks the operation for an index type.
89
90llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
92 llvm::SmallBitVector dimsToProject(shape.size());
93 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
94 if (shape[pos] == 1) {
95 dimsToProject.set(pos);
96 --rank;
97 }
98 }
99 return dimsToProject;
100}
101
103 OpFoldResult ofr) {
104 if (auto value = dyn_cast_if_present<Value>(ofr))
105 return value;
106 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
107 return arith::ConstantOp::create(
108 b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
109}
110
112 OpFoldResult ofr) {
113 if (auto value = dyn_cast_if_present<Value>(ofr))
114 return value;
115 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
116 return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue());
117}
118
120 Type targetType, Value value) {
121 if (targetType == value.getType())
122 return value;
123
124 bool targetIsIndex = targetType.isIndex();
125 bool valueIsIndex = value.getType().isIndex();
126 if (targetIsIndex ^ valueIsIndex)
127 return arith::IndexCastOp::create(b, loc, targetType, value);
128
129 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
130 auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
131 assert(targetIntegerType && valueIntegerType &&
132 "unexpected cast between types other than integers and index");
133 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
134
135 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
136 return arith::ExtSIOp::create(b, loc, targetIntegerType, value);
137 return arith::TruncIOp::create(b, loc, targetIntegerType, value);
138}
139
141 IntegerType toType, bool isUnsigned) {
142 // If operand is floating point, cast directly to the int type.
143 if (isa<FloatType>(operand.getType())) {
144 if (isUnsigned)
145 return arith::FPToUIOp::create(b, toType, operand);
146 return arith::FPToSIOp::create(b, toType, operand);
147 }
148 // Cast index operands directly to the int type.
149 if (operand.getType().isIndex())
150 return arith::IndexCastOp::create(b, toType, operand);
151 if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
152 // Either extend or truncate.
153 if (toType.getWidth() > fromIntType.getWidth()) {
154 if (isUnsigned)
155 return arith::ExtUIOp::create(b, toType, operand);
156 return arith::ExtSIOp::create(b, toType, operand);
157 }
158 if (toType.getWidth() < fromIntType.getWidth())
159 return arith::TruncIOp::create(b, toType, operand);
160 return operand;
161 }
162
163 return {};
164}
165
167 FloatType toType, bool isUnsigned) {
168 // If operand is integer, cast directly to the float type.
169 // Note that it is unclear how to cast from BF16<->FP16.
170 if (isa<IntegerType>(operand.getType())) {
171 if (isUnsigned)
172 return arith::UIToFPOp::create(b, toType, operand);
173 return arith::SIToFPOp::create(b, toType, operand);
174 }
175 if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
176 if (toType.getWidth() > fromFpTy.getWidth())
177 return arith::ExtFOp::create(b, toType, operand);
178 if (toType.getWidth() < fromFpTy.getWidth())
179 return arith::TruncFOp::create(b, toType, operand);
180 return operand;
181 }
182
183 return {};
184}
185
187 ComplexType targetType,
188 bool isUnsigned) {
189 if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
190 if (isa<FloatType>(targetType.getElementType()) &&
191 isa<FloatType>(fromComplexType.getElementType())) {
192 Value real = complex::ReOp::create(b, operand);
193 Value imag = complex::ImOp::create(b, operand);
194 Type targetETy = targetType.getElementType();
195 if (targetType.getElementType().getIntOrFloatBitWidth() <
196 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
197 real = arith::TruncFOp::create(b, targetETy, real);
198 imag = arith::TruncFOp::create(b, targetETy, imag);
199 } else {
200 real = arith::ExtFOp::create(b, targetETy, real);
201 imag = arith::ExtFOp::create(b, targetETy, imag);
202 }
203 return complex::CreateOp::create(b, targetType, real, imag);
204 }
205 }
206
207 if (isa<FloatType>(operand.getType())) {
208 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
209 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
210 Value from = operand;
211 if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
212 from = arith::ExtFOp::create(b, toFpTy, from);
213 }
214 if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
215 from = arith::TruncFOp::create(b, toFpTy, from);
216 }
218 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
219 return complex::CreateOp::create(b, targetType, from, zero);
220 }
221
222 if (isa<IntegerType>(operand.getType())) {
223 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
224 Value from = operand;
225 if (isUnsigned) {
226 from = arith::UIToFPOp::create(b, toFpTy, from);
227 } else {
228 from = arith::SIToFPOp::create(b, toFpTy, from);
229 }
231 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
232 return complex::CreateOp::create(b, targetType, from, zero);
233 }
234
235 return {};
236}
237
239 Type toType, bool isUnsignedCast) {
240 if (operand.getType() == toType)
241 return operand;
242 ImplicitLocOpBuilder ib(loc, b);
244 if (auto intTy = dyn_cast<IntegerType>(toType)) {
245 result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
246 } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
247 result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
248 } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
249 result =
250 convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
251 }
252
253 if (result)
254 return result;
255
256 emitWarning(loc) << "could not cast operand of type " << operand.getType()
257 << " to " << toType;
258 return operand;
259}
260
263 ArrayRef<OpFoldResult> valueOrAttrVec) {
264 return llvm::to_vector<4>(
265 llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
266 return getValueOrCreateConstantIndexOp(b, loc, value);
267 }));
268}
269
271 Type type, const APInt &value) {
272 TypedAttr attr;
273 if (isa<IntegerType>(type)) {
274 attr = builder.getIntegerAttr(type, value);
275 } else {
276 auto vecTy = cast<ShapedType>(type);
277 attr = SplatElementsAttr::get(vecTy, value);
278 }
279
280 return arith::ConstantOp::create(builder, loc, attr);
281}
282
284 Type type, int64_t value) {
285 unsigned elementBitWidth = 0;
286 if (auto intTy = dyn_cast<IntegerType>(type))
287 elementBitWidth = intTy.getWidth();
288 else
289 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
290
291 return createScalarOrSplatConstant(builder, loc, type,
292 APInt(elementBitWidth, value));
293}
294
296 Type type, const APFloat &value) {
297 if (isa<FloatType>(type))
298 return builder.createOrFold<arith::ConstantOp>(
299 loc, type, builder.getFloatAttr(type, value));
300 TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
301 return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
302}
303
305 if (auto value = dyn_cast_if_present<Value>(ofr))
306 return value.getType();
307 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
308 return attr.getType();
309}
310
312 return arith::AndIOp::create(b, loc, lhs, rhs);
313}
315 if (isa<FloatType>(lhs.getType()))
316 return arith::AddFOp::create(b, loc, lhs, rhs);
317 return arith::AddIOp::create(b, loc, lhs, rhs, ovf);
318}
320 if (isa<FloatType>(lhs.getType()))
321 return arith::SubFOp::create(b, loc, lhs, rhs);
322 return arith::SubIOp::create(b, loc, lhs, rhs, ovf);
323}
325 if (isa<FloatType>(lhs.getType()))
326 return arith::MulFOp::create(b, loc, lhs, rhs);
327 return arith::MulIOp::create(b, loc, lhs, rhs, ovf);
328}
330 if (isa<FloatType>(lhs.getType()))
331 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs);
332 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs);
333}
335 if (isa<FloatType>(lhs.getType()))
336 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs);
337 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs);
338}
340 return arith::SelectOp::create(b, loc, cmp, lhs, rhs);
341}
342
343namespace mlir::arith {
344
346 return createProduct(builder, loc, values, values.front().getType());
347}
348
350 Type resultType) {
351 Value one = ConstantOp::create(builder, loc, resultType,
352 builder.getOneAttr(resultType));
353 ArithBuilder arithBuilder(builder, loc);
354 return llvm::accumulate(values, one, [&arithBuilder](Value acc, Value v) {
355 return arithBuilder.mul(acc, v);
356 });
357}
358
359/// Map strings to float types.
360std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
361 Builder b(ctx);
363 .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
364 .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
365 .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
366 .Case("f8E5M2", b.getType<Float8E5M2Type>())
367 .Case("f8E4M3", b.getType<Float8E4M3Type>())
368 .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
369 .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
370 .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
371 .Case("f8E3M4", b.getType<Float8E3M4Type>())
372 .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
373 .Case("bf16", b.getType<BFloat16Type>())
374 .Case("f16", b.getType<Float16Type>())
375 .Case("f32", b.getType<Float32Type>())
376 .Case("f64", b.getType<Float64Type>())
377 .Case("f80", b.getType<Float80Type>())
378 .Case("f128", b.getType<Float128Type>())
379 .Default(std::nullopt);
380}
381
382} // namespace mlir::arith
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
Definition Utils.cpp:186
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
Definition Utils.cpp:140
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
Definition Utils.cpp:166
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:254
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:342
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:630
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
This class represents a single result from folding an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:330
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition Utils.cpp:360
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition Utils.cpp:345
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:238
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition Utils.cpp:270
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:102
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:119
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:111
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:23
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:90
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition Utils.cpp:86
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition Utils.h:103
Value mul(Value lhs, Value rhs)
Definition Utils.cpp:324
Value _and(Value lhs, Value rhs)
Definition Utils.cpp:311
Value slt(Value lhs, Value rhs)
Definition Utils.cpp:334
Value select(Value cmp, Value lhs, Value rhs)
Definition Utils.cpp:339
Value add(Value lhs, Value rhs)
Definition Utils.cpp:314
Value sgt(Value lhs, Value rhs)
Definition Utils.cpp:329
Value sub(Value lhs, Value rhs)
Definition Utils.cpp:319
The matcher that matches a certain kind of op.
Definition Matchers.h:283