18 #include "llvm/ADT/SmallBitVector.h"
23 std::optional<SmallVector<OpFoldResult>>
25 ShapedType expandedType,
32 if (inputShape.empty()) {
33 outputShapeInts.resize(expandedType.getRank(), 1);
38 if (expandedType.hasStaticShape()) {
40 outputShapeInts.assign(staticShape.begin(), staticShape.end());
44 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
48 int64_t indexGroupStaticSizesProductInt = 1;
49 bool foundDynamicShape =
false;
50 for (int64_t index : indexGroup) {
51 int64_t outputDimSize = expandedType.getDimSize(index);
54 if (ShapedType::isDynamic(outputDimSize)) {
55 if (foundDynamicShape)
57 foundDynamicShape =
true;
59 outputShapeInts[index] = outputDimSize;
60 indexGroupStaticSizesProductInt *= outputDimSize;
63 if (!foundDynamicShape)
66 int64_t inputIndex = it.index();
69 Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
70 Value indexGroupStaticSizesProduct =
73 loc, indexGroupSize, indexGroupStaticSizesProduct);
74 outputShapeValues.push_back(dynamicDimSize);
77 if ((int64_t)outputShapeValues.size() !=
78 llvm::count(outputShapeInts, ShapedType::kDynamic))
93 llvm::SmallBitVector dimsToProject(shape.size());
94 for (
unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
95 if (shape[pos] == 1) {
96 dimsToProject.set(pos);
100 return dimsToProject;
105 if (
auto value = dyn_cast_if_present<Value>(ofr))
107 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
108 return b.
create<arith::ConstantOp>(
109 loc, b.
getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
114 if (
auto value = dyn_cast_if_present<Value>(ofr))
116 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
122 if (targetType == value.
getType())
125 bool targetIsIndex = targetType.
isIndex();
127 if (targetIsIndex ^ valueIsIndex)
128 return b.
create<arith::IndexCastOp>(loc, targetType, value);
130 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
131 auto valueIntegerType = dyn_cast<IntegerType>(value.
getType());
132 assert(targetIntegerType && valueIntegerType &&
133 "unexpected cast between types other than integers and index");
134 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
136 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
137 return b.
create<arith::ExtSIOp>(loc, targetIntegerType, value);
138 return b.
create<arith::TruncIOp>(loc, targetIntegerType, value);
142 IntegerType toType,
bool isUnsigned) {
144 if (isa<FloatType>(operand.
getType())) {
146 return b.
create<arith::FPToUIOp>(toType, operand);
147 return b.
create<arith::FPToSIOp>(toType, operand);
151 return b.
create<arith::IndexCastOp>(toType, operand);
152 if (
auto fromIntType = dyn_cast<IntegerType>(operand.
getType())) {
154 if (toType.getWidth() > fromIntType.getWidth()) {
156 return b.
create<arith::ExtUIOp>(toType, operand);
157 return b.
create<arith::ExtSIOp>(toType, operand);
159 if (toType.getWidth() < fromIntType.getWidth())
160 return b.
create<arith::TruncIOp>(toType, operand);
171 if (isa<IntegerType>(operand.
getType())) {
173 return b.
create<arith::UIToFPOp>(toType, operand);
174 return b.
create<arith::SIToFPOp>(toType, operand);
176 if (
auto fromFpTy = dyn_cast<FloatType>(operand.
getType())) {
177 if (toType.
getWidth() > fromFpTy.getWidth())
178 return b.
create<arith::ExtFOp>(toType, operand);
179 if (toType.
getWidth() < fromFpTy.getWidth())
180 return b.
create<arith::TruncFOp>(toType, operand);
188 ComplexType targetType,
190 if (
auto fromComplexType = dyn_cast<ComplexType>(operand.
getType())) {
191 if (isa<FloatType>(targetType.getElementType()) &&
192 isa<FloatType>(fromComplexType.getElementType())) {
195 Type targetETy = targetType.getElementType();
196 if (targetType.getElementType().getIntOrFloatBitWidth() <
197 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
198 real = b.
create<arith::TruncFOp>(targetETy, real);
199 imag = b.
create<arith::TruncFOp>(targetETy, imag);
201 real = b.
create<arith::ExtFOp>(targetETy, real);
202 imag = b.
create<arith::ExtFOp>(targetETy, imag);
204 return b.
create<complex::CreateOp>(targetType, real, imag);
208 if (dyn_cast<FloatType>(operand.
getType())) {
209 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
211 Value from = operand;
213 from = b.
create<arith::ExtFOp>(toFpTy, from);
216 from = b.
create<arith::TruncFOp>(toFpTy, from);
220 return b.
create<complex::CreateOp>(targetType, from, zero);
223 if (dyn_cast<IntegerType>(operand.
getType())) {
224 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
225 Value from = operand;
227 from = b.
create<arith::UIToFPOp>(toFpTy, from);
229 from = b.
create<arith::SIToFPOp>(toFpTy, from);
233 return b.
create<complex::CreateOp>(targetType, from, zero);
240 Type toType,
bool isUnsignedCast) {
241 if (operand.
getType() == toType)
245 if (
auto intTy = dyn_cast<IntegerType>(toType)) {
247 }
else if (
auto floatTy = dyn_cast<FloatType>(toType)) {
249 }
else if (
auto complexTy = dyn_cast<ComplexType>(toType)) {
265 return llvm::to_vector<4>(
272 Type type,
const APInt &value) {
274 if (isa<IntegerType>(type)) {
277 auto vecTy = cast<ShapedType>(type);
281 return builder.
create<arith::ConstantOp>(loc, attr);
285 Type type, int64_t value) {
286 unsigned elementBitWidth = 0;
287 if (
auto intTy = dyn_cast<IntegerType>(type))
288 elementBitWidth = intTy.getWidth();
290 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
293 APInt(elementBitWidth, value));
297 Type type,
const APFloat &value) {
298 if (isa<FloatType>(type))
302 return builder.
createOrFold<arith::ConstantOp>(loc, type, splat);
306 if (
auto value = dyn_cast_if_present<Value>(ofr))
307 return value.getType();
308 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
309 return attr.getType();
313 return b.
create<arith::AndIOp>(loc, lhs, rhs);
316 if (isa<FloatType>(lhs.
getType()))
317 return b.
create<arith::AddFOp>(loc, lhs, rhs);
318 return b.
create<arith::AddIOp>(loc, lhs, rhs);
321 if (isa<FloatType>(lhs.
getType()))
322 return b.
create<arith::SubFOp>(loc, lhs, rhs);
323 return b.
create<arith::SubIOp>(loc, lhs, rhs);
326 if (isa<FloatType>(lhs.
getType()))
327 return b.
create<arith::MulFOp>(loc, lhs, rhs);
328 return b.
create<arith::MulIOp>(loc, lhs, rhs);
331 if (isa<FloatType>(lhs.
getType()))
332 return b.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
333 return b.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
336 if (isa<FloatType>(lhs.
getType()))
337 return b.
create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
338 return b.
create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
341 return b.
create<arith::SelectOp>(loc, cmp, lhs, rhs);
347 return createProduct(builder, loc, values, values.front().getType());
352 Value one = builder.
create<ConstantOp>(loc, resultType,
355 return std::accumulate(
356 values.begin(), values.end(), one,
357 [&arithBuilder](
Value acc,
Value v) { return arithBuilder.mul(acc, v); });
380 .Default(std::nullopt);
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
This class is a general helper class for creating context-global objects like types,...
FloatType getFloat8E5M2Type()
FloatType getFloat8E8M0FNUType()
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatType getFloat6E3M2FNType()
FloatAttr getFloatAttr(Type type, double value)
FloatType getFloat8E3M4Type()
FloatType getFloat8E4M3Type()
FloatType getFloat4E2M1FNType()
FloatType getFloat8E4M3FNType()
FloatType getFloat6E2M3FNType()
FloatType getFloat8E4M3FNUZType()
FloatType getFloat8E5M2FNUZType()
TypedAttr getOneAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getWidth()
Return the bitwidth of this float type.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Specialization of arith.constant op that returns a floating point value.
Specialization of arith.constant op that returns an integer of index type.
std::optional< FloatType > parseFloatType(MLIRContext *ctx, StringRef name)
Map strings to float types.
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, Builder &b)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Helper struct to build simple arithmetic quantities with minimal type inference support.
Value mul(Value lhs, Value rhs)
Value _and(Value lhs, Value rhs)
Value slt(Value lhs, Value rhs)
Value select(Value cmp, Value lhs, Value rhs)
Value add(Value lhs, Value rhs)
Value sgt(Value lhs, Value rhs)
Value sub(Value lhs, Value rhs)
The matcher that matches a certain kind of op.