19#include "llvm/ADT/SmallBitVector.h"
20#include "llvm/ADT/SmallVectorExtras.h"
25std::optional<SmallVector<OpFoldResult>>
27 ShapedType expandedType,
34 if (inputShape.empty()) {
35 outputShapeInts.resize(expandedType.getRank(), 1);
40 if (expandedType.hasStaticShape()) {
42 outputShapeInts.assign(staticShape.begin(), staticShape.end());
46 outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic);
47 for (
const auto &it : llvm::enumerate(reassociation)) {
50 int64_t indexGroupStaticSizesProductInt = 1;
51 bool foundDynamicShape =
false;
56 if (ShapedType::isDynamic(outputDimSize)) {
57 if (foundDynamicShape)
59 foundDynamicShape =
true;
61 outputShapeInts[
index] = outputDimSize;
62 indexGroupStaticSizesProductInt *= outputDimSize;
65 if (!foundDynamicShape)
68 int64_t inputIndex = it.index();
71 Value indexGroupSize = cast<Value>(inputShape[inputIndex]);
72 Value indexGroupStaticSizesProduct =
74 Value dynamicDimSize =
b.createOrFold<arith::DivSIOp>(
75 loc, indexGroupSize, indexGroupStaticSizesProduct);
76 outputShapeValues.push_back(dynamicDimSize);
79 if ((
int64_t)outputShapeValues.size() !=
80 llvm::count(outputShapeInts, ShapedType::kDynamic))
95 llvm::SmallBitVector dimsToProject(
shape.size());
96 for (
unsigned pos = 0, e =
shape.size(); pos < e && rank > 0; ++pos) {
97 if (
shape[pos] == 1) {
98 dimsToProject.set(pos);
102 return dimsToProject;
107 if (
auto value = dyn_cast_if_present<Value>(ofr))
109 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
110 return arith::ConstantOp::create(
111 b, loc,
b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue()));
116 if (
auto value = dyn_cast_if_present<Value>(ofr))
118 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
124 if (targetType == value.
getType())
127 bool targetIsIndex = targetType.
isIndex();
129 if (targetIsIndex ^ valueIsIndex)
130 return arith::IndexCastOp::create(
b, loc, targetType, value);
132 auto targetIntegerType = dyn_cast<IntegerType>(targetType);
133 auto valueIntegerType = dyn_cast<IntegerType>(value.
getType());
134 assert(targetIntegerType && valueIntegerType &&
135 "unexpected cast between types other than integers and index");
136 assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
138 if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
139 return arith::ExtSIOp::create(
b, loc, targetIntegerType, value);
140 return arith::TruncIOp::create(
b, loc, targetIntegerType, value);
144 IntegerType toType,
bool isUnsigned) {
146 if (isa<FloatType>(operand.
getType())) {
148 return arith::FPToUIOp::create(
b, toType, operand);
149 return arith::FPToSIOp::create(
b, toType, operand);
153 return arith::IndexCastOp::create(
b, toType, operand);
154 if (
auto fromIntType = dyn_cast<IntegerType>(operand.
getType())) {
156 if (toType.getWidth() > fromIntType.getWidth()) {
158 return arith::ExtUIOp::create(
b, toType, operand);
159 return arith::ExtSIOp::create(
b, toType, operand);
161 if (toType.getWidth() < fromIntType.getWidth())
162 return arith::TruncIOp::create(
b, toType, operand);
170 FloatType toType,
bool isUnsigned) {
173 if (isa<IntegerType>(operand.
getType())) {
175 return arith::UIToFPOp::create(
b, toType, operand);
176 return arith::SIToFPOp::create(
b, toType, operand);
178 if (
auto fromFpTy = dyn_cast<FloatType>(operand.
getType())) {
179 if (toType.getWidth() > fromFpTy.getWidth())
180 return arith::ExtFOp::create(
b, toType, operand);
181 if (toType.getWidth() < fromFpTy.getWidth())
182 return arith::TruncFOp::create(
b, toType, operand);
190 ComplexType targetType,
192 if (
auto fromComplexType = dyn_cast<ComplexType>(operand.
getType())) {
193 if (isa<FloatType>(targetType.getElementType()) &&
194 isa<FloatType>(fromComplexType.getElementType())) {
195 Value real = complex::ReOp::create(
b, operand);
196 Value imag = complex::ImOp::create(
b, operand);
197 Type targetETy = targetType.getElementType();
198 if (targetType.getElementType().getIntOrFloatBitWidth() <
199 fromComplexType.getElementType().getIntOrFloatBitWidth()) {
200 real = arith::TruncFOp::create(
b, targetETy, real);
201 imag = arith::TruncFOp::create(
b, targetETy, imag);
203 real = arith::ExtFOp::create(
b, targetETy, real);
204 imag = arith::ExtFOp::create(
b, targetETy, imag);
206 return complex::CreateOp::create(
b, targetType, real, imag);
210 if (isa<FloatType>(operand.
getType())) {
211 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
212 auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
213 Value from = operand;
215 from = arith::ExtFOp::create(
b, toFpTy, from);
218 from = arith::TruncFOp::create(
b, toFpTy, from);
221 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
222 return complex::CreateOp::create(
b, targetType, from, zero);
225 if (isa<IntegerType>(operand.
getType())) {
226 FloatType toFpTy = cast<FloatType>(targetType.getElementType());
227 Value from = operand;
229 from = arith::UIToFPOp::create(
b, toFpTy, from);
231 from = arith::SIToFPOp::create(
b, toFpTy, from);
234 b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0));
235 return complex::CreateOp::create(
b, targetType, from, zero);
242 Type toType,
bool isUnsignedCast) {
243 if (operand.
getType() == toType)
247 if (
auto intTy = dyn_cast<IntegerType>(toType)) {
249 }
else if (
auto floatTy = dyn_cast<FloatType>(toType)) {
251 }
else if (
auto complexTy = dyn_cast<ComplexType>(toType)) {
267 return llvm::map_to_vector<4>(
274 Type type,
const APInt &value) {
276 if (isa<IntegerType>(type)) {
279 auto vecTy = cast<ShapedType>(type);
283 return arith::ConstantOp::create(builder, loc, attr);
288 unsigned elementBitWidth = 0;
289 if (
auto intTy = dyn_cast<IntegerType>(type))
290 elementBitWidth = intTy.getWidth();
292 elementBitWidth = cast<ShapedType>(type).getElementTypeBitWidth();
295 APInt(elementBitWidth, value));
299 Type type,
const APFloat &value) {
300 if (isa<FloatType>(type))
304 return builder.
createOrFold<arith::ConstantOp>(loc, type, splat);
308 if (
auto value = dyn_cast_if_present<Value>(ofr))
309 return value.getType();
310 auto attr = cast<IntegerAttr>(cast<Attribute>(ofr));
311 return attr.getType();
315 return arith::AndIOp::create(b, loc,
lhs,
rhs);
318 if (isa<FloatType>(
lhs.getType()))
319 return arith::AddFOp::create(b, loc,
lhs,
rhs);
320 return arith::AddIOp::create(b, loc,
lhs,
rhs, ovf);
323 if (isa<FloatType>(
lhs.getType()))
324 return arith::SubFOp::create(b, loc,
lhs,
rhs);
325 return arith::SubIOp::create(b, loc,
lhs,
rhs, ovf);
328 if (isa<FloatType>(
lhs.getType()))
329 return arith::MulFOp::create(b, loc,
lhs,
rhs);
330 return arith::MulIOp::create(b, loc,
lhs,
rhs, ovf);
333 if (isa<FloatType>(
lhs.getType()))
334 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT,
lhs,
rhs);
335 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt,
lhs,
rhs);
338 if (isa<FloatType>(
lhs.getType()))
339 return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT,
lhs,
rhs);
340 return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt,
lhs,
rhs);
343 return arith::SelectOp::create(b, loc, cmp,
lhs,
rhs);
349 return createProduct(builder, loc, values, values.front().getType());
354 Value one = ConstantOp::create(builder, loc, resultType,
357 return llvm::accumulate(values, one, [&arithBuilder](
Value acc,
Value v) {
358 return arithBuilder.
mul(
acc, v);
365 if (!name.empty() && name.front() ==
'!')
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getOneAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static ConstantFloatOp create(OpBuilder &builder, Location location, FloatType type, const APFloat &value)
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
FloatType parseFloatType(MLIRContext *ctx, StringRef name)
Value createProduct(OpBuilder &builder, Location loc, ArrayRef< Value > values)
Include the generated interface declarations.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Value createScalarOrSplatConstant(OpBuilder &builder, Location loc, Type type, const APInt &value)
Create a constant of type type at location loc whose value is value (an APInt or APFloat whose type m...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
SmallVector< int64_t, 2 > ReassociationIndices
Helper struct to build simple arithmetic quantities with minimal type inference support.
Value mul(Value lhs, Value rhs)
Value _and(Value lhs, Value rhs)
Value slt(Value lhs, Value rhs)
Value select(Value cmp, Value lhs, Value rhs)
Value add(Value lhs, Value rhs)
Value sgt(Value lhs, Value rhs)
Value sub(Value lhs, Value rhs)
The matcher that matches a certain kind of op.