MLIR 23.0.0git
Utils.cpp
Go to the documentation of this file.
//===- Utils.cpp - Utilities to support the Arith dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Arith dialect.
//
//===----------------------------------------------------------------------===//
12
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/SmallVectorExtras.h"
19#include <numeric>
20
21using namespace mlir;
22
23std::optional<SmallVector<OpFoldResult>>
25 ShapedType expandedType,
27 ArrayRef<OpFoldResult> inputShape) {
28
29 SmallVector<Value> outputShapeValues;
30 SmallVector<int64_t> outputShapeInts;
31 // For zero-rank inputs, all dims in result shape are unit extent.
32 if (inputShape.empty()) {
33 outputShapeInts.resize(expandedType.getRank(), 1);
34 return getMixedValues(outputShapeInts, outputShapeValues, b);
35 }
36
37 // Check for all static shapes.
38 if (expandedType.hasStaticShape()) {
39 ArrayRef<int64_t> staticShape = expandedType.getShape();
40 outputShapeInts.assign(staticShape.begin(), staticShape.end());
41 return getMixedValues(outputShapeInts, outputShapeValues, b);
42 }
43
44 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
45 for (const auto &it : llvm::enumerate(reassociation)) {
46 ReassociationIndices indexGroup = it.value();
47
48 int64_t indexGroupStaticSizesProductInt = 1;
49 bool foundDynamicShape = false;
50 for (int64_t index : indexGroup) {
51 int64_t outputDimSize = expandedType.getDimSize(index);
52 // Cannot infer expanded shape with multiple dynamic dims in the
53 // same reassociation group!
54 if (ShapedType::isDynamic(outputDimSize)) {
55 if (foundDynamicShape)
56 return std::nullopt;
57 foundDynamicShape = true;
58 } else {
59 outputShapeInts[index] = outputDimSize;
60 indexGroupStaticSizesProductInt *= outputDimSize;
61 }
62 }
63 if (!foundDynamicShape)
64 continue;
65
66 int64_t inputIndex = it.index();
67 // Call get<Value>() under the assumption that we're not casting
68 // dynamism.
69 Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
70 Value indexGroupStaticSizesProduct =
71 arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt);
72 Value dynamicDimSize = b.createOrFold<arith::DivSIOp>(
73 loc, indexGroupSize, indexGroupStaticSizesProduct);
74 outputShapeValues.push_back(dynamicDimSize);
75 }
76
77 if ((int64_t)outputShapeValues.size() !=
78 llvm::count(outputShapeInts, ShapedType::kDynamic))
79 return std::nullopt;
80
81 return getMixedValues(outputShapeInts, outputShapeValues, b);
82}
83
84/// Matches a ConstantIndexOp.
85/// TODO: This should probably just be a general matcher that uses matchConstant
86/// and checks the operation for an index type.
90
91llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
93 llvm::SmallBitVector dimsToProject(shape.size());
94 for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
95 if (shape[pos] == 1) {
96 dimsToProject.set(pos);
97 --rank;
98 }
99 }
100 return dimsToProject;
101}
102
104 OpFoldResult ofr) {
105 if (auto value = dyn_cast_if_present<Value>(ofr))
106 return value;
107 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
108 return arith::ConstantOp::create(
109 b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
110}
111
113 OpFoldResult ofr) {
114 if (auto value = dyn_cast_if_present<Value>(ofr))
115 return value;
116 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
117 return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue());
118}
119
121 Type targetType, Value value) {
122 if (targetType == value.getType())
123 return value;
124
125 bool targetIsIndex = targetType.isIndex();
126 bool valueIsIndex = value.getType().isIndex();
127 if (targetIsIndex ^ valueIsIndex)
128 return arith::IndexCastOp::create(b, loc, targetType, value);
129
130 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
131 auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
132 assert(targetIntegerType && valueIntegerType &&
133 "unexpected cast between types other than integers and index");
134 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
135
136 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
137 return arith::ExtSIOp::create(b, loc, targetIntegerType, value);
138 return arith::TruncIOp::create(b, loc, targetIntegerType, value);
139}
140
142 IntegerType toType, bool isUnsigned) {
143 // If operand is floating point, cast directly to the int type.
144 if (isa<FloatType>(operand.getType())) {
145 if (isUnsigned)
146 return arith::FPToUIOp::create(b, toType, operand);
147 return arith::FPToSIOp::create(b, toType, operand);
148 }
149 // Cast index operands directly to the int type.
150 if (operand.getType().isIndex())
151 return arith::IndexCastOp::create(b, toType, operand);
152 if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
153 // Either extend or truncate.
154 if (toType.getWidth() > fromIntType.getWidth()) {
155 if (isUnsigned)
156 return arith::ExtUIOp::create(b, toType, operand);
157 return arith::ExtSIOp::create(b, toType, operand);
158 }
159 if (toType.getWidth() < fromIntType.getWidth())
160 return arith::TruncIOp::create(b, toType, operand);
161 return operand;
162 }
163
164 return {};
165}
166
168 FloatType toType, bool isUnsigned) {
169 // If operand is integer, cast directly to the float type.
170 // Note that it is unclear how to cast from BF16<->FP16.
171 if (isa<IntegerType>(operand.getType())) {
172 if (isUnsigned)
173 return arith::UIToFPOp::create(b, toType, operand);
174 return arith::SIToFPOp::create(b, toType, operand);
175 }
176 if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
177 if (toType.getWidth() > fromFpTy.getWidth())
178 return arith::ExtFOp::create(b, toType, operand);
179 if (toType.getWidth() < fromFpTy.getWidth())
180 return arith::TruncFOp::create(b, toType, operand);
181 return operand;
182 }
183
184 return {};
185}
186
188 ComplexType targetType,
189 bool isUnsigned) {
190 if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
191 if (isa<FloatType>(targetType.getElementType()) &&
192 isa<FloatType>(fromComplexType.getElementType())) {
193 Value real = complex::ReOp::create(b, operand);
194 Value imag = complex::ImOp::create(b, operand);
195 Type targetETy = targetType.getElementType();
196 if (targetType.getElementType().getIntOrFloatBitWidth() <
197 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
198 real = arith::TruncFOp::create(b, targetETy, real);
199 imag = arith::TruncFOp::create(b, targetETy, imag);
200 } else {
201 real = arith::ExtFOp::create(b, targetETy, real);
202 imag = arith::ExtFOp::create(b, targetETy, imag);
203 }
204 return complex::CreateOp::create(b, targetType, real, imag);
205 }
206 }
207
208 if (isa<FloatType>(operand.getType())) {
209 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
210 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
211 Value from = operand;
212 if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
213 from = arith::ExtFOp::create(b, toFpTy, from);
214 }
215 if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
216 from = arith::TruncFOp::create(b, toFpTy, from);
217 }
219 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
220 return complex::CreateOp::create(b, targetType, from, zero);
221 }
222
223 if (isa<IntegerType>(operand.getType())) {
224 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
225 Value from = operand;
226 if (isUnsigned) {
227 from = arith::UIToFPOp::create(b, toFpTy, from);
228 } else {
229 from = arith::SIToFPOp::create(b, toFpTy, from);
230 }
232 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
233 return complex::CreateOp::create(b, targetType, from, zero);
234 }
235
236 return {};
237}
238
240 Type toType, bool isUnsignedCast) {
241 if (operand.getType() == toType)
242 return operand;
243 ImplicitLocOpBuilder ib(loc, b);
245 if (auto intTy = dyn_cast<IntegerType>(toType)) {
246 result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
247 } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
248 result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
249 } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
250 result =
251 convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
252 }
253
254 if (result)
255 return result;
256
257 emitWarning(loc) << "could not cast operand of type " << operand.getType()
258 << " to " << toType;
259 return operand;
260}
261
264 ArrayRef<OpFoldResult> valueOrAttrVec) {
265 return llvm::map_to_vector<4>(
266 valueOrAttrVec, [&](OpFoldResult value) -> Value {
267 return getValueOrCreateConstantIndexOp(b, loc, value);
268 });
269}
270
272 Type type, const APInt &value) {
273 TypedAttr attr;
274 if (isa<IntegerType>(type)) {
275 attr = builder.getIntegerAttr(type, value);
276 } else {
277 auto vecTy = cast<ShapedType>(type);
278 attr = SplatElementsAttr::get(vecTy, value);
279 }
280
281 return arith::ConstantOp::create(builder, loc, attr);
282}
283
285 Type type, int64_t value) {
286 unsigned elementBitWidth = 0;
287 if (auto intTy = dyn_cast<IntegerType>(type))
288 elementBitWidth = intTy.getWidth();
289 else
290 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
291
292 return createScalarOrSplatConstant(builder, loc, type,
293 APInt(elementBitWidth, value));
294}
295
297 Type type, const APFloat &value) {
298 if (isa<FloatType>(type))
299 return builder.createOrFold<arith::ConstantOp>(
300 loc, type, builder.getFloatAttr(type, value));
301 TypedAttr splat = SplatElementsAttr::get(cast<ShapedType>(type), value);
302 return builder.createOrFold<arith::ConstantOp>(loc, type, splat);
303}
304
306 if (auto value = dyn_cast_if_present<Value>(ofr))
307 return value.getType();
308 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
309 return attr.getType();
310}
311
313 return arith::AndIOp::create(b, loc, lhs, rhs);
314}
316 if (isa<FloatType>(lhs.getType()))
317 return arith::AddFOp::create(b, loc, lhs, rhs);
318 return arith::AddIOp::create(b, loc, lhs, rhs, ovf);
319}
321 if (isa<FloatType>(lhs.getType()))
322 return arith::SubFOp::create(b, loc, lhs, rhs);
323 return arith::SubIOp::create(b, loc, lhs, rhs, ovf);
324}
326 if (isa<FloatType>(lhs.getType()))
327 return arith::MulFOp::create(b, loc, lhs, rhs);
328 return arith::MulIOp::create(b, loc, lhs, rhs, ovf);
329}
331 if (isa<FloatType>(lhs.getType()))
332 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs);
333 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs);
334}
336 if (isa<FloatType>(lhs.getType()))
337 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs);
338 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs);
339}
341 return arith::SelectOp::create(b, loc, cmp, lhs, rhs);
342}
343
344namespace mlir::arith {
345
347 return createProduct(builder, loc, values, values.front().getType());
348}
349
351 Type resultType) {
352 Value one = ConstantOp::create(builder, loc, resultType,
353 builder.getOneAttr(resultType));
354 ArithBuilder arithBuilder(builder, loc);
355 return llvm::accumulate(values, one, [&arithBuilder](Value acc, Value v) {
356 return arithBuilder.mul(acc, v);
357 });
358}
359
360/// Map strings to float types.
361std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
362 Builder b(ctx);
364 .Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
365 .Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
366 .Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
367 .Case("f8E5M2", b.getType<Float8E5M2Type>())
368 .Case("f8E4M3", b.getType<Float8E4M3Type>())
369 .Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
370 .Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
371 .Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
372 .Case("f8E3M4", b.getType<Float8E3M4Type>())
373 .Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
374 .Case("bf16", b.getType<BFloat16Type>())
375 .Case("f16", b.getType<Float16Type>())
376 .Case("f32", b.getType<Float32Type>())
377 .Case("f64", b.getType<Float64Type>())
378 .Case("f80", b.getType<Float80Type>())
379 .Case("f128", b.getType<Float128Type>())
380 .Default(std::nullopt);
381}
382
383} // namespace mlir::arith
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
Definition Utils.cpp:187
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
Definition Utils.cpp:141
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
Definition Utils.cpp:167
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
TypedAttr getOneAttr(Type type)
Definition Builders.cpp:346
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
Definition ArithOps.cpp:334
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Definition Utils.cpp:361
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Definition Utils.cpp:346
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition Utils.cpp:239
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Definition Utils.cpp:271
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:103
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition Utils.cpp:120
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:112
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:24
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition Utils.cpp:91
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition Utils.cpp:87
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
Helper struct to build simple arithmetic quantities with minimal type inference support.
Definition Utils.h:103
Value mul(Value lhs, Value rhs)
Definition Utils.cpp:325
Value _and(Value lhs, Value rhs)
Definition Utils.cpp:312
Value slt(Value lhs, Value rhs)
Definition Utils.cpp:335
Value select(Value cmp, Value lhs, Value rhs)
Definition Utils.cpp:340
Value add(Value lhs, Value rhs)
Definition Utils.cpp:315
Value sgt(Value lhs, Value rhs)
Definition Utils.cpp:330
Value sub(Value lhs, Value rhs)
Definition Utils.cpp:320
The matcher that matches a certain kind of op.
Definition Matchers.h:283