17 #include "llvm/ADT/SmallBitVector.h"
22 std::optional<SmallVector<OpFoldResult>>
24 ShapedType expandedType,
31 if (inputShape.empty()) {
32 outputShapeInts.resize(expandedType.getRank(), 1);
37 if (expandedType.hasStaticShape()) {
39 outputShapeInts.assign(staticShape.begin(), staticShape.end());
43 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
47 int64_t indexGroupStaticSizesProductInt = 1;
48 bool foundDynamicShape =
false;
49 for (int64_t index : indexGroup) {
50 int64_t outputDimSize = expandedType.getDimSize(index);
53 if (ShapedType::isDynamic(outputDimSize)) {
54 if (foundDynamicShape)
56 foundDynamicShape =
true;
58 outputShapeInts[index] = outputDimSize;
59 indexGroupStaticSizesProductInt *= outputDimSize;
62 if (!foundDynamicShape)
65 int64_t inputIndex = it.index();
68 Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
69 Value indexGroupStaticSizesProduct =
72 loc, indexGroupSize, indexGroupStaticSizesProduct);
73 outputShapeValues.push_back(dynamicDimSize);
76 if ((int64_t)outputShapeValues.size() !=
77 llvm::count(outputShapeInts, ShapedType::kDynamic))
92 llvm::SmallBitVector dimsToProject(shape.size());
93 for (
unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
94 if (shape[pos] == 1) {
95 dimsToProject.set(pos);
104 if (
auto value = dyn_cast_if_present<Value>(ofr))
106 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
107 return arith::ConstantOp::create(
108 b, loc, b.
getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
113 if (
auto value = dyn_cast_if_present<Value>(ofr))
115 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
121 if (targetType == value.
getType())
124 bool targetIsIndex = targetType.
isIndex();
126 if (targetIsIndex ^ valueIsIndex)
127 return arith::IndexCastOp::create(b, loc, targetType, value);
129 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
130 auto valueIntegerType = dyn_cast<IntegerType>(value.
getType());
131 assert(targetIntegerType && valueIntegerType &&
132 "unexpected cast between types other than integers and index");
133 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
135 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
136 return arith::ExtSIOp::create(b, loc, targetIntegerType, value);
137 return arith::TruncIOp::create(b, loc, targetIntegerType, value);
141 IntegerType toType,
bool isUnsigned) {
143 if (isa<FloatType>(operand.
getType())) {
145 return arith::FPToUIOp::create(b, toType, operand);
146 return arith::FPToSIOp::create(b, toType, operand);
150 return arith::IndexCastOp::create(b, toType, operand);
151 if (
auto fromIntType = dyn_cast<IntegerType>(operand.
getType())) {
153 if (toType.getWidth() > fromIntType.getWidth()) {
155 return arith::ExtUIOp::create(b, toType, operand);
156 return arith::ExtSIOp::create(b, toType, operand);
158 if (toType.getWidth() < fromIntType.getWidth())
159 return arith::TruncIOp::create(b, toType, operand);
167 FloatType toType,
bool isUnsigned) {
170 if (isa<IntegerType>(operand.
getType())) {
172 return arith::UIToFPOp::create(b, toType, operand);
173 return arith::SIToFPOp::create(b, toType, operand);
175 if (
auto fromFpTy = dyn_cast<FloatType>(operand.
getType())) {
176 if (toType.getWidth() > fromFpTy.getWidth())
177 return arith::ExtFOp::create(b, toType, operand);
178 if (toType.getWidth() < fromFpTy.getWidth())
179 return arith::TruncFOp::create(b, toType, operand);
187 ComplexType targetType,
189 if (
auto fromComplexType = dyn_cast<ComplexType>(operand.
getType())) {
190 if (isa<FloatType>(targetType.getElementType()) &&
191 isa<FloatType>(fromComplexType.getElementType())) {
192 Value real = complex::ReOp::create(b, operand);
193 Value imag = complex::ImOp::create(b, operand);
194 Type targetETy = targetType.getElementType();
195 if (targetType.getElementType().getIntOrFloatBitWidth() <
196 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
197 real = arith::TruncFOp::create(b, targetETy, real);
198 imag = arith::TruncFOp::create(b, targetETy, imag);
200 real = arith::ExtFOp::create(b, targetETy, real);
201 imag = arith::ExtFOp::create(b, targetETy, imag);
203 return complex::CreateOp::create(b, targetType, real, imag);
207 if (isa<FloatType>(operand.
getType())) {
208 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
209 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
210 Value from = operand;
212 from = arith::ExtFOp::create(b, toFpTy, from);
215 from = arith::TruncFOp::create(b, toFpTy, from);
218 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
219 return complex::CreateOp::create(b, targetType, from, zero);
222 if (isa<IntegerType>(operand.
getType())) {
223 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
224 Value from = operand;
226 from = arith::UIToFPOp::create(b, toFpTy, from);
228 from = arith::SIToFPOp::create(b, toFpTy, from);
231 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
232 return complex::CreateOp::create(b, targetType, from, zero);
239 Type toType,
bool isUnsignedCast) {
240 if (operand.
getType() == toType)
244 if (
auto intTy = dyn_cast<IntegerType>(toType)) {
246 }
else if (
auto floatTy = dyn_cast<FloatType>(toType)) {
248 }
else if (
auto complexTy = dyn_cast<ComplexType>(toType)) {
264 return llvm::to_vector<4>(
271 Type type,
const APInt &value) {
273 if (isa<IntegerType>(type)) {
276 auto vecTy = cast<ShapedType>(type);
280 return arith::ConstantOp::create(builder, loc, attr);
284 Type type, int64_t value) {
285 unsigned elementBitWidth = 0;
286 if (
auto intTy = dyn_cast<IntegerType>(type))
287 elementBitWidth = intTy.getWidth();
289 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
292 APInt(elementBitWidth, value));
296 Type type,
const APFloat &value) {
297 if (isa<FloatType>(type))
301 return builder.
createOrFold<arith::ConstantOp>(loc, type, splat);
305 if (
auto value = dyn_cast_if_present<Value>(ofr))
306 return value.getType();
307 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
308 return attr.getType();
312 return arith::AndIOp::create(b, loc, lhs, rhs);
315 if (isa<FloatType>(lhs.
getType()))
316 return arith::AddFOp::create(b, loc, lhs, rhs);
317 return arith::AddIOp::create(b, loc, lhs, rhs, ovf);
320 if (isa<FloatType>(lhs.
getType()))
321 return arith::SubFOp::create(b, loc, lhs, rhs);
322 return arith::SubIOp::create(b, loc, lhs, rhs, ovf);
325 if (isa<FloatType>(lhs.
getType()))
326 return arith::MulFOp::create(b, loc, lhs, rhs);
327 return arith::MulIOp::create(b, loc, lhs, rhs, ovf);
330 if (isa<FloatType>(lhs.
getType()))
331 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs);
332 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs);
335 if (isa<FloatType>(lhs.
getType()))
336 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs);
337 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs);
340 return arith::SelectOp::create(b, loc, cmp, lhs, rhs);
346 return createProduct(builder, loc, values, values.front().getType());
351 Value one = ConstantOp::create(builder, loc, resultType,
354 return std::accumulate(
355 values.begin(), values.end(), one,
356 [&arithBuilder](
Value acc,
Value v) { return arithBuilder.mul(acc, v); });
363 .Case(
"f4E2M1FN", b.
getType<Float4E2M1FNType>())
364 .Case(
"f6E2M3FN", b.
getType<Float6E2M3FNType>())
365 .Case(
"f6E3M2FN", b.
getType<Float6E3M2FNType>())
366 .Case(
"f8E5M2", b.
getType<Float8E5M2Type>())
367 .Case(
"f8E4M3", b.
getType<Float8E4M3Type>())
368 .Case(
"f8E4M3FN", b.
getType<Float8E4M3FNType>())
369 .Case(
"f8E5M2FNUZ", b.
getType<Float8E5M2FNUZType>())
370 .Case(
"f8E4M3FNUZ", b.
getType<Float8E4M3FNUZType>())
371 .Case(
"f8E3M4", b.
getType<Float8E3M4Type>())
372 .Case(
"f8E8M0FNU", b.
getType<Float8E8M0FNUType>())
373 .Case(
"bf16", b.
getType<BFloat16Type>())
374 .Case(
"f16", b.
getType<Float16Type>())
375 .Case(
"f32", b.
getType<Float32Type>())
376 .Case(
"f64", b.
getType<Float64Type>())
377 .Case(
"f80", b.
getType<Float80Type>())
378 .Case(
"f128", b.
getType<Float128Type>())
379 .Default(std::nullopt);
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
TypedAttr getOneAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Helper struct to build simple arithmetic quantities with minimal type inference support.
Value mul(Value lhs, Value rhs)
Value _and(Value lhs, Value rhs)
Value slt(Value lhs, Value rhs)
Value select(Value cmp, Value lhs, Value rhs)
Value add(Value lhs, Value rhs)
Value sgt(Value lhs, Value rhs)
Value sub(Value lhs, Value rhs)
The matcher that matches a certain kind of op.