29 for (
const auto &en :
enumerate(type.getShape())) {
31 if (ShapedType::isDynamic(en.value()))
40 return b.
create<PadOp>(loc, type, source, low, high, pad, nofold);
46 auto tensorTy = cast<RankedTensorType>(rankedTensor.
getType());
49 if (en.value() == ShapedType::kDynamic)
50 dynamicDims.push_back(
51 b.
create<tensor::DimOp>(loc, rankedTensor, en.index()));
59 if (transposeVector.empty())
60 return rankedTensorType;
63 transposeVector.size() !=
static_cast<size_t>(rankedTensorType.getRank()))
67 rankedTensorType.getShape().end());
71 RankedTensorType transposedTensorType =
72 RTTBuilder(rankedTensorType).
setShape(transposedShape);
73 return transposedTensorType;
77 llvm::SmallBitVector droppedDims = op.getDroppedDims();
81 for (int64_t resultDim = 0; resultDim < op.getDestType().getRank();
83 if (droppedDims.test(resultDim)) {
87 op.getSource(), op.
getResult(), srcDim, resultDim);
88 if (
failed(equalDimSize) || !*equalDimSize)
97 llvm::SmallBitVector droppedDims = op.getDroppedDims();
98 int64_t resultDim = 0;
101 RankedTensorType sourceType = op.getSourceType();
102 for (int64_t dim = 0, e = sourceType.getRank(); dim < e; ++dim) {
103 if (droppedDims.test(dim)) {
106 if (sourceType.getDimSize(dim) != 1)
111 op.getSource(), op.
getResult(), dim, resultDim);
112 if (
failed(equalDimSize) || !*equalDimSize)
Base type for affine expression.
IntegerAttr getIndexAttr(int64_t value)
MLIRContext * getContext() const
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents a single result from folding an operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
static FailureOr< bool > areEqual(Value value1, Value value2, std::optional< int64_t > dim1=std::nullopt, std::optional< int64_t > dim2=std::nullopt)
Compute whether the given values/dimensions are equal.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< Value > createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor)
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
PadOp createPadHighOp(RankedTensorType type, Value source, Value pad, bool nofold, Location loc, OpBuilder &builder)
FailureOr< RankedTensorType > computeTransposedType(RankedTensorType rankedTensorType, ArrayRef< int64_t > transposeVector)
Returns the transposed rankedTensorType if transposeVector is non-empty.
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.