#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return complex::ConstantOp::create(builder, loc, type,
                                       llvm::cast<ArrayAttr>(value));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  return destOp.getTiedOpOperand(opResult)->get();
  if (!tensorType.hasStaticShape()) {
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
  if (llvm::isa<TensorType>(opResult.getType())) {
    if (failed(destination))
      return failure();
    result.push_back(*destination);
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
    if (!isStaticUnitSize) {
    if (reducedShape[shapePos] == 1) {
      droppedDims.set(idx);
  assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      if (cst.has_value()) {
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
        staticShape[i] = *cst;
      foldedDynamicSizes.push_back(dynamicSize);

  return RankedTensorType::get(staticShape, type.getElementType(),
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;
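
// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast.
// Illustrative IR sketch (example assumed, not from the original file):
//   %1 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xf32>
//   %2 = tensor.bitcast %1 : tensor<4xf32> to tensor<4xui32>
// folds to
//   %2 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xui32>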
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
  results.add<ChainedTensorBitcast>(context);

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  if (!sourceType || !targetType)
    return false;

  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  if (sourceType.getRank() != targetType.getRank())
    return false;

  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (ShapedType::isStatic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
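
// Behavior sketch of preservesStaticInformation above (assumed examples):
//   (source = tensor<?x?xf32>, target = tensor<4x?xf32>) --> true
//   (source = tensor<4x?xf32>, target = tensor<?x?xf32>) --> false
// The target may refine dynamic dims to static ones, but must not lose static
// information the source already carries. canFoldIntoConsumerOp applies this
// with the cast result type as source and the cast input type as target.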
  if (llvm::isa<BlockArgument>(opOperand.get()))
    return false;
  auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
  return castOp && canFoldIntoConsumerOp(castOp);
  newOperands.reserve(op->getNumOperands());

  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp)
      operand.set(castOp.getOperand());
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (aT.getElementType() != bT.getElementType())
    return false;
  if (rank != two.getRank())
    return {};

  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
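
// Join sketch for the shape-join logic above (illustrative): joining
// tensor<?x4xf32> with tensor<8x?xf32> picks the static extent per dimension,
// yielding 8x4; two different static extents in the same position have no
// join, and the function signals failure by returning an empty type.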
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);

    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
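
// Illustrative fold performed by the pattern above (example assumed):
//   %2 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
//   %3 = tensor.cast %2 : tensor<4x?xf32> to tensor<4x8xf32>
// becomes
//   %3 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
// provided dropping the intermediate type loses no shape information.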
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};
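
// Illustrative rewrite for the pattern above (example assumed): a cast that
// only adds static size information is folded into the producing slice:
//   %1 = tensor.extract_slice %0[0, 0][%s, 4][1, 1]
//          : tensor<128x512xf32> to tensor<?x4xf32>
//   %2 = tensor.cast %1 : tensor<?x4xf32> to tensor<16x4xf32>
// becomes
//   %2 = tensor.extract_slice %0[0, 0][16, 4][1, 1]
//          : tensor<128x512xf32> to tensor<16x4xf32>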
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  auto tensorTypes =
      llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
        return llvm::cast<RankedTensorType>(type);
      }));
  int64_t concatRank = tensorTypes[0].getRank();

  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    for (auto tensorType : tensorTypes)

  for (auto tensorType : tensorTypes)

  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());

  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}
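
// Inference sketch (illustrative): concatenating tensor<3x4xf32> and
// tensor<5x4xf32> along dim 0 infers tensor<8x4xf32>; a dynamic extent in any
// input makes the concatenated dimension dynamic, and non-concatenated
// dimensions must agree (modulo dynamic sizes).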
LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
    }
  }

  for (auto tensorType : inputTypes)

  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  inputShapes.reserve(numInputs);
  concatOffsets.reserve(numInputs);

  for (auto [index, input] : llvm::enumerate(getInputs())) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    inputShapes.emplace_back(std::move(inputShape));
  }

  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = tensor::InsertSliceOp::create(
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];

  for (int64_t i = 0; i < rank; ++i) {
    if (!getType().isDynamicDim(i)) {
    } else if (!inferredResultType.isDynamicDim(i)) {
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
  } else {
    reifiedReturnShapes[0][dim] =

void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];

struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
};
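
// The next pattern propagates a more static inferred shape into the operands
// of a tensor.concat by inserting casts. Sketch (assumed example): if the
// inferred result shape is 8x?xf32 (concatenating along dim 1), an operand of
// type tensor<?x3xf32> is refined via
//   tensor.cast ... : tensor<?x3xf32> to tensor<8x3xf32>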
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    LogicalResult matched = failure();
    // Inferred operand shapes are identical in every dimension except the
    // concatenation dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      auto castOp =
          CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                         concatOp.getOperand(operandIdx));
        concatOp->setOperand(operandIdx, castOp->getResult(0));
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType())) {

    auto newConcatOp =
        ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                         concatOp->getOperands());

  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}
std::optional<int64_t> DimOp::getConstantIndex() {

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)

  if (rankedSourceType.getRank() <= constantIndex)

  setResultRange(getResult(),

  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());

  if (indexVal < 0 || indexVal >= tensorType.getRank())

  if (!tensorType.isDynamicDim(index.getInt())) {
  Operation *definingOp = getSource().getDefiningOp();

  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex)) {
      return {sliceOp.getDynamicSize(unsignedIndex)};
    }
  }
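
// Folding sketch for the cases above (illustrative): a tensor.dim with a
// constant index folds to the matching dynamic extent of a producing
// tensor.generate, or to the dynamic size operand of a producing
// extract_slice. E.g. with %t = tensor.generate %e0, %e1 : tensor<?x3x?xf32>,
// `tensor.dim %t, %c2` folds to %e1 (the second dynamic extent).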
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    Value newSource = castOp.getOperand();
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    Location loc = dim.getLoc();
    Value extract =
        ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);

  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}
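
// Builder usage sketch (illustrative, assumed values): mixed sizes are split
// into a static shape plus dynamic size operands, so sizes = {%d, 4} builds
//   %0 = tensor.empty(%d) : tensor<?x4xf32>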
LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;

  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
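
// Illustrative fold performed by the pattern above (example assumed):
//   %0 = tensor.empty(%d) : tensor<?x4xf32>
//   %1 = tensor.dim %0, %c0 : tensor<?x4xf32>
// replaces %1 with %d, the dynamic size tied to dimension 0.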
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);

      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      newMixedSizes.push_back(currDim);
    }

    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};
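
// Illustrative rewrite for the pattern above (example assumed): an extract
// from a statically shaped collapse_shape is redirected to the source by
// delinearizing the collapsed index over the group's source sizes:
//   %c = tensor.collapse_shape %s [[0, 1]] : tensor<2x3xf32> into tensor<6xf32>
//   %e = tensor.extract %c[%i] : tensor<6xf32>
// becomes an affine.delinearize_index of %i over basis (2, 3) feeding
//   %e = tensor.extract %s[%i0, %i1] : tensor<2x3xf32>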
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");

  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }
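
  // Linearization sketch for the from_elements fold above (illustrative):
  // for a tensor<2x3xf32> and indices (1, 2), the row-major flat index is
  // 1 * 3 + 2 = 5; the loop accumulates strides from the innermost dimension
  // (stride 1) outwards.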
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}

void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();

    auto newExtract = tensor::ExtractOp::create(
        rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
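
// Result type sketch (illustrative, assumed shapes): gathering
// gather_dims([1]) from a source tensor<4x5x6xf32> with indices
// tensor<7x1xindex> yields tensor<7x4x1x6xf32>, or tensor<7x4x6xf32> in the
// rank-reduced variant (the gathered dimension keeps a unit size unless
// rank-reduced away).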
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (!indices.empty() && indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type mismatch: expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()

OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();

  if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
    if (scalar == splatDest.getSplatValue<Attribute>())
      return dest;
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));

  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
    }
  }
  return success();
}

LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
  return success();
}

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body region.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();

    Block *body = &tensorFromElements.getBody().front();

      rewriter.clone(op, mapping);

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

  for (auto dim : type.getShape())

LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");

OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)

  if (sourceTy.getRank() <= 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
  return *outputShape;
}

SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
  build(builder, result, cast<RankedTensorType>(resultType), src,
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {

SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
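
// Type inference sketch (illustrative): collapsing tensor<2x3x?xf32> with
// reassociation [[0, 1], [2]] multiplies static extents per group, giving
// tensor<6x?xf32>; any dynamic extent inside a group makes the collapsed
// dimension dynamic.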
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto resultType = inferCollapsedType(
      llvm::cast<RankedTensorType>(src.getType()),

  result.addAttribute(getReassociationAttrStrName(),

  build(b, result, resultType, src, attrs);
}

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypesWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}
LogicalResult ExpandShapeOp::verify() {
  auto srcType = getSrcType();
  auto resultType = getResultType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
LogicalResult CollapseShapeOp::verify() {

template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;

    auto newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                           newResultType, castOp.getSource(),
                                           collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values; nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
    ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

RankedTensorType ExtractSliceOp::inferResultType(
    RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticSizes;

  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;

  auto shape = inferredType.getShape();
  llvm::SmallBitVector dimsToProject =
  SmallVector<int64_t> projectedShape;

  for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
    if (!dimsToProject.test(pos))
      projectedShape.push_back(shape[pos]);
  inferredType =
      RankedTensorType::get(projectedShape, inferredType.getElementType());

  return inferredType;
}
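
// Rank-reduction sketch (illustrative): with static sizes (1, 4, 1, 8) and a
// desired result rank of 2, the inferred tensor<1x4x1x8xf32> drops unit dims
// in order (best effort) to become tensor<4x8xf32>.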
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;

  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
        sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}
LogicalResult ExtractSliceOp::verify() {
  RankedTensorType sourceType = getSourceType();

  // Verify result type against inferred type.
  RankedTensorType expectedType = ExtractSliceOp::inferResultType(
      sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());

  // Verify that offsets + sizes - 1 are in-bounds for the source tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {

FailureOr<Value>
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
                                   ArrayRef<int64_t> desiredShape) {
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return ExtractSliceOp::create(
      b, loc, value,
      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());

  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());

    // Create folded extract.
    Location loc = sliceOp.getLoc();
    Value newResult = ExtractSliceOp::create(
        rewriter, loc, sliceOp.getType(), castOp.getSource(),
        sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
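
// Recursion sketch (illustrative, assumed constant): slicing a 2x4 source
// with offsets (0, 1), sizes (2, 2), strides (1, 2) makes the outer call
// advance by whole rows (counts.front() elements per row) while the innermost
// call copies elements 1 and 3 of each selected row into outValues.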
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;

    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    int64_t count = sourceType.getNumElements();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // New attribute constructed by the sliced values.
    DenseElementsAttr newAttr;

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
    }

  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
    return ExtractSliceOp::inferCanonicalRankReducedResultType(
        op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
        mixedStrides);

                  ExtractSliceOp newOp) {
      replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
                                           replacement);

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))

static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();

  return OpFoldResult();
}

Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();

  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}

  result.addAttributes(attrs);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));

void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<Range> ranges,
                          ArrayRef<NamedAttribute> attrs) {
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}

void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
static SliceVerificationResult verifyInsertSliceOp(
    RankedTensorType srcType, RankedTensorType dstType,

  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
  if (expectedType)
    *expectedType = expected;

LogicalResult InsertSliceOp::verify() {
  // insert_slice is the inverse of extract_slice; use the same type inference.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);

  // Verify that offsets + sizes - 1 are in-bounds for the destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);
/// If we have two consecutive InsertSliceOp writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's.
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.getDestMutable().assign(prevInsertOp.getDest());
  return success();
}
/// Folds round-trip extract/insert slice op pairs.
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    return nullptr;

  return extractOp.getSource();
}
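// Round-trip example for the fold above (shapes assumed):
//   %0 = tensor.extract_slice %val[0, 0] [2, 4] [1, 1]
//        : tensor<8x16xf32> to tensor<2x4xf32>
//   %1 = tensor.insert_slice %0 into %val[0, 0] [2, 4] [1, 1]
//        : tensor<2x4xf32> into tensor<8x16xf32>
// %1 re-inserts the slice it just extracted, so it folds to %val.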
OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  if (auto result = foldInsertAfterExtractSlice(*this))
    return result;
  return OpFoldResult();
}
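// Sketch of the identity case above (shapes assumed): when source and result
// types match and the slice spans the whole destination,
//   %0 = tensor.insert_slice %src into %dst[0, 0] [8, 16] [1, 1]
//        : tensor<8x16xf32> into tensor<8x16xf32>
// folds to %src, because the insertion overwrites the entire destination.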
LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
  return success();
}
/// Pattern to rewrite an insert_slice op with constant arguments.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return.
    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
        failed(foldDynamicStrideList(mixedStrides)))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        insertSliceOp.getDestType().getShape(), mixedOffsets, mixedSizes,
        mixedStrides);
    if (!sliceResult.isValid)
      return failure();

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedOffsets, mixedSizes, mixedStrides);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the parallel combining op
      // in the parallel case.
      if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
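// Illustrative rewrite performed by the pattern above (shapes assumed):
//   %c0 = arith.constant 0 : index
//   %0 = tensor.insert_slice %src into %dst[%c0, 1] [2, 4] [1, 1]
//        : tensor<2x4xf32> into tensor<8x16xf32>
// becomes
//   %0 = tensor.insert_slice %src into %dst[0, 1] [2, 4] [1, 1]
//        : tensor<2x4xf32> into tensor<8x16xf32>
// with a tensor.cast of the source inserted first if folding the constants
// changes the inferred (rank-reduced) source type.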
/// Fold tensor_casts with insert_slice operations. If the source or
/// destination tensor is a tensor_cast that removes static type information,
/// the cast is folded into the insert_slice operation.
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is casted to ensure that the type of the result
/// did not change.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The tensor.cast source could have additional static information not
    // seen in the insert slice op static sizes, so we ignore dynamic dims
    // when computing the rank reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();

    // Replace dimensions in the insert slice op with the corresponding
    // static dims from the tensor.cast source.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    // Pattern does not apply if the produced op would not verify.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
                            mixedSizes, insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    Operation *replacement =
        InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                           insertSliceOp.getMixedOffsets(), mixedSizes,
                           insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                           insertSliceOp.getDestType(),
                                           replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
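// Example of the folding performed above (shapes assumed):
//   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
//   %2 = tensor.insert_slice %1 into %dst[0, 0] [8, 16] [1, 1]
//        : tensor<?x?xf32> into tensor<16x32xf32>
// folds into
//   %2 = tensor.insert_slice %0 into %dst[0, 0] [8, 16] [1, 1]
//        : tensor<8x16xf32> into tensor<16x32xf32>
// When the cast is on the destination, the op result is casted back so the
// result type is unchanged.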
/// If additional static type information can be deduced from an
/// insert_slice's size operands, insert an explicit cast of the op's source
/// operand. This enables other canonicalization patterns that are matching
/// for tensor_cast ops such as InsertSliceOpCastFolder.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is different from srcType, "more static" than srcType, and
    // cast-compatible with srcType: insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
    // that the insertion point is just before the parallel combining op in
    // the parallel case.
    if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
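// Sketch of the effect of the pattern above (shapes assumed): if the op
// inserts a tensor<?x?xf32> with static sizes [2, 4], the source is casted
// first:
//   %c = tensor.cast %src : tensor<?x?xf32> to tensor<2x4xf32>
//   %0 = tensor.insert_slice %c into %dst[0, 0] [2, 4] [1, 1]
//        : tensor<2x4xf32> into tensor<8x16xf32>
// exposing the static type information deduced from the size operands.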
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
                                                             Location loc,
                                                             Value tensor,
                                                             Value dest) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "padded");
}
LogicalResult PadOp::verify() {
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  return success();
}
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank =
      llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // Note: the number and type of yield values are checked in the YieldOp.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");

  return success();
}
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
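// Worked example for the inference above (values assumed): for a
// tensor<4x?xf32> source with staticLow = [1, 2] and staticHigh =
// [2, kDynamic], dimension 0 is fully static and yields 4 + 1 + 2 = 7,
// while dimension 1 involves a dynamic size or padding and stays dynamic,
// giving tensor<7x?xf32>.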
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<int64_t> staticLow,
                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  result.addAttributes(attrs);
  build(b, result, resultType, source, low, high,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ValueRange low, ValueRange high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        nofold, attrs);
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh have full information of the padding config.
  // This will grow staticLow and staticHigh with 1 value. If the config is
  // dynamic (i.e. not a constant), dynamicLow and dynamicHigh will grow with
  // 1 value as well.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
  if (!resultType)
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  assert(llvm::isa<RankedTensorType>(resultType));
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, Value constantPadValue,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  build(b, result, resultType, source, low, high, nofold, attrs);

  // Add a region and a block to yield the pad value.
  Region *region = result.regions[0].get();
  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
  SmallVector<Location> blockArgLocs(sourceRank, result.location);

  // `b.createBlock` changes the insertion point within the block. Create a
  // guard to reset the insertion point of the builder after it is destroyed.
  OpBuilder::InsertionGuard guard(b);
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
  tensor::YieldOp::create(b, result.location, constantPadValue);
}
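// Hypothetical builder usage for the overload above (all names are assumed,
// not from this file):
//   Value pad = ...;   // e.g. %cst = arith.constant 0.0 : f32
//   Value src = ...;   // e.g. a tensor<4x4xf32>
//   SmallVector<OpFoldResult> low(2, b.getIndexAttr(1));
//   SmallVector<OpFoldResult> high(2, b.getIndexAttr(1));
//   auto padOp = PadOp::create(b, loc, /*resultType=*/Type(), src, low, high,
//                              pad, /*nofold=*/false);
// The generated region simply yields the constant pad value:
//   ^bb0(%i: index, %j: index):
//     tensor.yield %cst : f32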
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
// Folds tensor.pad when padding is static zeros.
struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
    return success();
  }
};
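// Example for the pattern above (shapes assumed): a pad with all-zero
// low/high padding does not change the data, so
//   %0 = tensor.pad %t low[0, 0] high[0, 0] { ... }
//        : tensor<4x4xf32> to tensor<4x4xf32>
// is replaced by a tensor.cast of %t to the result type (unless nofold is
// set).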
// Fold CastOp into PadOp when adding static information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = PadOp::create(
          rewriter, padTensorOp->getLoc(), newResultType,
          padTensorOp.getSource(), padTensorOp.getStaticLow(),
          padTensorOp.getStaticHigh(), padTensorOp.getLow(),
          padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};
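// Example of the rewrite above (shapes assumed):
//   %0 = tensor.cast %t : tensor<4x4xf32> to tensor<?x?xf32>
//   %1 = tensor.pad %0 low[1, 1] high[1, 1] { ... }
//        : tensor<?x?xf32> to tensor<?x?xf32>
// pads the cast's source directly, inferring the static result type
// tensor<6x6xf32> and casting back to tensor<?x?xf32> where needed.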
// Fold a CastOp using the result of a PadOp back into the latter if it adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = PadOp::create(
        rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
/// different dimensions. Requires non-rank-reducing, unit-stride slices and
/// high-padding-only pads with the same constant padding value and no common
/// padding dimensions.
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // Fail if the padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // Combine the offsets of the two tensor::ExtractSliceOps. Find the
    // zero-offset and zero-padding dimensions and copy the other
    // tensor::ExtractSliceOp's offset.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto &en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) &&
          (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) &&
          (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // Combine the sizes of the two tensor::ExtractSliceOps. Take the size of
    // the outer slice for the padded dimensions and the size of the inner
    // slice for the unpadded dimensions.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto &en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto &en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
    // the two paddings in one step.
    auto newSliceOp = ExtractSliceOp::create(
        rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
        newSizes, innerSliceOp.getMixedStrides());
    auto newPadOp = PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        newSliceOp.getResult(), padOp.getMixedLowPad(), newHighPad,
        padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
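// Sketch of the chain the pattern above folds (shapes assumed): an
// extract_slice between two pads that pad disjoint ("orthogonal")
// dimensions, e.g.
//   %0 = tensor.pad %src ... high[0, 4]   // pads dim 1
//   %1 = tensor.extract_slice %0 ...
//   %2 = tensor.pad %1 ... high[4, 0]     // pads dim 0
// is rewritten into a single extract_slice of %src followed by one pad that
// applies both high paddings at once.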
/// If the low/high padding amounts and the input shape provide enough static
/// information, compute a more static result type for the pad and cast back
/// to the old result type.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);

    return success();
  }
};
/// Folds a chain of tensor.pad ops with the same constant padding value.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad op");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
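// Example for the pattern above (shapes assumed): two pads yielding the same
// constant value,
//   %0 = tensor.pad %t low[0, 1] high[1, 0] { tensor.yield %cst ... }
//   %1 = tensor.pad %0 low[1, 0] high[0, 1] { tensor.yield %cst ... }
// merge into a single pad of %t with low[1, 1] and high[1, 1]; the per-dim
// sums are computed via makeComposedFoldedAffineApply on d0 + d1.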
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    Location loc = getLoc();
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));

    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
  return success();
}
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
/// Return the padding value of the PadOp if it is constant. In this context,
/// "constant" means an actual constant or "defined outside of the block".
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // Check if the yield value is a constant.
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  // Check if the yield value is defined inside the PadOp block.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  // Else: the yield value is defined outside of the PadOp block.
  return padValue;
}
OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
/// Build a ParallelInsertSliceOp from a list of ranges.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest, ValueRange offsets,
                                  ValueRange sizes, ValueRange strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
      llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
      llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
  SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
      llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected InParallelOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, and strides do not run out-of-bounds with
  // respect to the destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}

MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}

Operation *ParallelInsertSliceOp::getIteratingParent() {
  // Return the parent of the combining parent, i.e. the op driving the
  // parallel iteration.
  if (auto combiningOp =
          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
    return combiningOp->getParentOp();
  return nullptr;
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");
  // TODO: we could also check statically that there are fewer leading index
  // tensor dims than the dest dims. If this is not the case, the unique
  // attribute cannot be true anyway.

  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}
LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] =
          builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as being
  // a splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
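// Example of the fold above (values assumed): splatting a constant into a
// statically shaped result,
//   %cst = arith.constant 1.0 : f32
//   %0 = tensor.splat %cst : tensor<2x3xf32>
// folds to a SplatElementsAttr, i.e. dense<1.0> : tensor<2x3xf32>.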
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // InsertSliceOp has its own logic about folding tensor.cast ops.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  return hasFoldableTensorCastOperand(op);
}

/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the tensor.cast has a source that is more static than the consuming op.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reject relayout ops: there are dedicated patterns for those instead.
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Clone the op with the updated operands and result types.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(tensor::CastOp::create(
            rewriter, op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
Attributes are known-constant values of operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
AffineExpr getAffineDimExpr(unsigned position)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.