17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/SmallVectorExtras.h"
23std::optional<SmallVector<OpFoldResult>>
25 ShapedType expandedType,
32 if (inputShape.empty()) {
33 outputShapeInts.resize(expandedType.getRank(), 1);
38 if (expandedType.hasStaticShape()) {
40 outputShapeInts.assign(staticShape.begin(), staticShape.end());
44 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
45 for (
const auto &it : llvm::enumerate(reassociation)) {
48 int64_t indexGroupStaticSizesProductInt = 1;
49 bool foundDynamicShape =
false;
54 if (ShapedType::isDynamic(outputDimSize)) {
55 if (foundDynamicShape)
57 foundDynamicShape =
true;
59 outputShapeInts[
index] = outputDimSize;
60 indexGroupStaticSizesProductInt *= outputDimSize;
63 if (!foundDynamicShape)
66 int64_t inputIndex = it.index();
69 Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
70 Value indexGroupStaticSizesProduct =
72 Value dynamicDimSize =
b.createOrFold<arith::DivSIOp>(
73 loc, indexGroupSize, indexGroupStaticSizesProduct);
74 outputShapeValues.push_back(dynamicDimSize);
77 if ((
int64_t)outputShapeValues.size() !=
78 llvm::count(outputShapeInts, ShapedType::kDynamic))
93 llvm::SmallBitVector dimsToProject(
shape.size());
94 for (
unsigned pos = 0, e =
shape.size(); pos < e && rank > 0; ++pos) {
95 if (
shape[pos] == 1) {
96 dimsToProject.set(pos);
100 return dimsToProject;
105 if (
auto value = dyn_cast_if_present<Value>(ofr))
107 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
108 return arith::ConstantOp::create(
109 b, loc,
b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
114 if (
auto value = dyn_cast_if_present<Value>(ofr))
116 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
122 if (targetType == value.
getType())
125 bool targetIsIndex = targetType.
isIndex();
127 if (targetIsIndex ^ valueIsIndex)
128 return arith::IndexCastOp::create(
b, loc, targetType, value);
130 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
131 auto valueIntegerType = dyn_cast<IntegerType>(value.
getType());
132 assert(targetIntegerType && valueIntegerType &&
133 "unexpected cast between types other than integers and index");
134 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
136 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
137 return arith::ExtSIOp::create(
b, loc, targetIntegerType, value);
138 return arith::TruncIOp::create(
b, loc, targetIntegerType, value);
142 IntegerType toType,
bool isUnsigned) {
144 if (isa<FloatType>(operand.
getType())) {
146 return arith::FPToUIOp::create(
b, toType, operand);
147 return arith::FPToSIOp::create(
b, toType, operand);
151 return arith::IndexCastOp::create(
b, toType, operand);
152 if (
auto fromIntType = dyn_cast<IntegerType>(operand.
getType())) {
154 if (toType.getWidth() > fromIntType.getWidth()) {
156 return arith::ExtUIOp::create(
b, toType, operand);
157 return arith::ExtSIOp::create(
b, toType, operand);
159 if (toType.getWidth() < fromIntType.getWidth())
160 return arith::TruncIOp::create(
b, toType, operand);
168 FloatType toType,
bool isUnsigned) {
171 if (isa<IntegerType>(operand.
getType())) {
173 return arith::UIToFPOp::create(
b, toType, operand);
174 return arith::SIToFPOp::create(
b, toType, operand);
176 if (
auto fromFpTy = dyn_cast<FloatType>(operand.
getType())) {
177 if (toType.getWidth() > fromFpTy.getWidth())
178 return arith::ExtFOp::create(
b, toType, operand);
179 if (toType.getWidth() < fromFpTy.getWidth())
180 return arith::TruncFOp::create(
b, toType, operand);
188 ComplexType targetType,
190 if (
auto fromComplexType = dyn_cast<ComplexType>(operand.
getType())) {
191 if (isa<FloatType>(targetType.getElementType()) &&
192 isa<FloatType>(fromComplexType.getElementType())) {
193 Value real = complex::ReOp::create(
b, operand);
194 Value imag = complex::ImOp::create(
b, operand);
195 Type targetETy = targetType.getElementType();
196 if (targetType.getElementType().getIntOrFloatBitWidth() <
197 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
198 real = arith::TruncFOp::create(
b, targetETy, real);
199 imag = arith::TruncFOp::create(
b, targetETy, imag);
201 real = arith::ExtFOp::create(
b, targetETy, real);
202 imag = arith::ExtFOp::create(
b, targetETy, imag);
204 return complex::CreateOp::create(
b, targetType, real, imag);
208 if (isa<FloatType>(operand.
getType())) {
209 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
210 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
211 Value from = operand;
213 from = arith::ExtFOp::create(
b, toFpTy, from);
216 from = arith::TruncFOp::create(
b, toFpTy, from);
219 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
220 return complex::CreateOp::create(
b, targetType, from, zero);
223 if (isa<IntegerType>(operand.
getType())) {
224 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
225 Value from = operand;
227 from = arith::UIToFPOp::create(
b, toFpTy, from);
229 from = arith::SIToFPOp::create(
b, toFpTy, from);
232 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
233 return complex::CreateOp::create(
b, targetType, from, zero);
240 Type toType,
bool isUnsignedCast) {
241 if (operand.
getType() == toType)
245 if (
auto intTy = dyn_cast<IntegerType>(toType)) {
247 }
else if (
auto floatTy = dyn_cast<FloatType>(toType)) {
249 }
else if (
auto complexTy = dyn_cast<ComplexType>(toType)) {
265 return llvm::map_to_vector<4>(
272 Type type,
const APInt &value) {
274 if (isa<IntegerType>(type)) {
277 auto vecTy = cast<ShapedType>(type);
281 return arith::ConstantOp::create(builder, loc, attr);
286 unsigned elementBitWidth = 0;
287 if (
auto intTy = dyn_cast<IntegerType>(type))
288 elementBitWidth = intTy.getWidth();
290 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
293 APInt(elementBitWidth, value));
297 Type type,
const APFloat &value) {
298 if (isa<FloatType>(type))
302 return builder.
createOrFold<arith::ConstantOp>(loc, type, splat);
306 if (
auto value = dyn_cast_if_present<Value>(ofr))
307 return value.getType();
308 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
309 return attr.getType();
313 return arith::AndIOp::create(b, loc,
lhs,
rhs);
316 if (isa<FloatType>(
lhs.getType()))
317 return arith::AddFOp::create(b, loc,
lhs,
rhs);
318 return arith::AddIOp::create(b, loc,
lhs,
rhs, ovf);
321 if (isa<FloatType>(
lhs.getType()))
322 return arith::SubFOp::create(b, loc,
lhs,
rhs);
323 return arith::SubIOp::create(b, loc,
lhs,
rhs, ovf);
326 if (isa<FloatType>(
lhs.getType()))
327 return arith::MulFOp::create(b, loc,
lhs,
rhs);
328 return arith::MulIOp::create(b, loc,
lhs,
rhs, ovf);
331 if (isa<FloatType>(
lhs.getType()))
332 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT,
lhs,
rhs);
333 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt,
lhs,
rhs);
336 if (isa<FloatType>(
lhs.getType()))
337 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT,
lhs,
rhs);
338 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt,
lhs,
rhs);
341 return arith::SelectOp::create(b, loc, cmp,
lhs,
rhs);
347 return createProduct(builder, loc, values, values.front().getType());
352 Value one = ConstantOp::create(builder, loc, resultType,
355 return llvm::accumulate(values, one, [&arithBuilder](
Value acc,
Value v) {
356 return arithBuilder.
mul(
acc, v);
364 .Case(
"f4E2M1FN",
b.getType<Float4E2M1FNType>())
365 .Case(
"f6E2M3FN",
b.getType<Float6E2M3FNType>())
366 .Case(
"f6E3M2FN",
b.getType<Float6E3M2FNType>())
367 .Case(
"f8E5M2",
b.getType<Float8E5M2Type>())
368 .Case(
"f8E4M3",
b.getType<Float8E4M3Type>())
369 .Case(
"f8E4M3FN",
b.getType<Float8E4M3FNType>())
370 .Case(
"f8E5M2FNUZ",
b.getType<Float8E5M2FNUZType>())
371 .Case(
"f8E4M3FNUZ",
b.getType<Float8E4M3FNUZType>())
372 .Case(
"f8E3M4",
b.getType<Float8E3M4Type>())
373 .Case(
"f8E8M0FNU",
b.getType<Float8E8M0FNUType>())
374 .Case(
"bf16",
b.getType<BFloat16Type>())
375 .Case(
"f16",
b.getType<Float16Type>())
376 .Case(
"f32",
b.getType<Float32Type>())
377 .Case(
"f64",
b.getType<Float64Type>())
378 .Case(
"f80",
b.getType<Float80Type>())
379 .Case(
"f128",
b.getType<Float128Type>())
380 .Default(std::nullopt);
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getOneAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
SmallVector< int64_t, 2 > ReassociationIndices
Helper struct to build simple arithmetic quantities with minimal type inference support.
Value mul(Value lhs, Value rhs)
Value _and(Value lhs, Value rhs)
Value slt(Value lhs, Value rhs)
Value select(Value cmp, Value lhs, Value rhs)
Value add(Value lhs, Value rhs)
Value sgt(Value lhs, Value rhs)
Value sub(Value lhs, Value rhs)
The matcher that matches a certain kind of op.