#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
  if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
    return op;
  if (complex::ConstantOp::isBuildableWith(value, type))
    return complex::ConstantOp::create(builder, loc, type,
                                       llvm::cast<ArrayAttr>(value));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  if (tensorType.isDynamicDim(dim))
    return builder.createOrFold<tensor::DimOp>(loc, value, dim);
  return builder.getIndexAttr(tensorType.getDimSize(dim));
  auto tensorType = llvm::cast<RankedTensorType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
  auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
  assert(tensorType && "expected tensor type");

  // If the op has a destination, query the tied operand from
  // DestinationStyleOpInterface.
  auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
  if (destOp)
    return destOp.getTiedOpOperand(opResult)->get();
  if (!tensorType.hasStaticShape()) {
    // Dynamic shape: query ReifyRankedShapedTypeOpInterface.
    ReifiedRankedShapedTypeDims reifiedShapes;
    if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
      return failure();
    mixedSizes = reifiedShapes[opResult.getResultNumber()];
  } else {
    // Static shape: take the static sizes directly.
    for (int64_t sz : tensorType.getShape())
      mixedSizes.push_back(b.getIndexAttr(sz));
  }

  // Create an empty tensor with the computed sizes.
  Value emptyTensor =
      tensor::EmptyOp::create(b, loc, mixedSizes, tensorType.getElementType());
  return emptyTensor;
    if (llvm::isa<TensorType>(opResult.getType())) {
      FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
      if (failed(destination))
        return failure();
      result.push_back(*destination);
    }
  if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
    if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
      return rtp1.getShape() == rtp2.getShape() &&
             rtp1.getElementType() == rtp2.getElementType();
    return false;
  }
  return tp1 == tp2; // default: exact type equality
  llvm::SmallBitVector droppedDims(mixedSizes.size());
  int64_t shapePos = reducedShape.size() - 1;

  for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
    size_t idx = mixedSizes.size() - size.index() - 1;
    // Rank-reduced dims must have a static unit dimension.
    bool isStaticUnitSize =
        isa<Attribute>(size.value()) &&
        llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;

    if (shapePos < 0) {
      // There are no more dims in the reduced shape. All remaining sizes must
      // be rank-reduced dims.
      assert(isStaticUnitSize && "expected unit dim");
      droppedDims.set(idx);
      continue;
    }

    // Dim is preserved if the size is not a static 1.
    if (!isStaticUnitSize) {
      --shapePos;
      continue;
    }

    // Dim is preserved if the reduced shape dim is also 1.
    if (reducedShape[shapePos] == 1) {
      --shapePos;
      continue;
    }

    // Otherwise: the dim is dropped.
    droppedDims.set(idx);
  }

  assert(shapePos < 0 && "dimension mismatch");
  return droppedDims;
/// Compute a tensor type in which every dynamic size whose SSA value is a
/// non-negative constant becomes a static size. Remaining dynamic size values
/// are collected in `foldedDynamicSizes`.
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
                            SmallVector<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
    if (type.isDynamicDim(i)) {
      Value dynamicSize = dynamicSizes[ctr++];
      std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
      if (cst.has_value()) {
        // Dynamic size must be non-negative.
        if (cst.value() < 0) {
          foldedDynamicSizes.push_back(dynamicSize);
          continue;
        }
        staticShape[i] = *cst;
      } else {
        foldedDynamicSizes.push_back(dynamicSize);
      }
    }
  }

  return RankedTensorType::get(staticShape, type.getElementType(),
                               type.getEncoding());
}
bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = dyn_cast<TensorType>(a);
  auto bT = dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}
namespace {
/// Replace chained tensor.bitcast ops by a single bitcast from the original
/// operand.
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
  using OpRewritePattern<BitcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
                                PatternRewriter &rewriter) const final {
    auto tensorBitcastOperand =
        tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
    if (!tensorBitcastOperand)
      return failure();

    auto resultType = cast<TensorType>(tensorBitcast.getType());
    rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                           tensorBitcastOperand.getOperand());
    return success();
  }
};
} // namespace

void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ChainedTensorBitcast>(context);
}
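// A hedged IR sketch (our addition, not from the original source) of what the
// pattern above folds:
//
//   %0 = tensor.bitcast %arg : tensor<4xi32> to tensor<4xui32>
//   %1 = tensor.bitcast %0 : tensor<4xui32> to tensor<4xsi32>
//
// becomes a single `tensor.bitcast %arg : tensor<4xi32> to tensor<4xsi32>`,
// assuming all three element types share the same bit width.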
  setNameFn(getResult(), "cast");
  auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
  auto targetType = llvm::dyn_cast<RankedTensorType>(target);

  // Requires RankedTensorType.
  if (!sourceType || !targetType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != targetType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != targetType.getRank())
    return false;

  // Requires same encoding.
  if (sourceType.getEncoding() != targetType.getEncoding())
    return false;

  // If the cast turns any dimension from static to dynamic, static
  // information is lost.
  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
    if (ShapedType::isStatic(std::get<0>(t)) &&
        ShapedType::isDynamic(std::get<1>(t)))
      return false;
  }

  return true;
  return preservesStaticInformation(castOp.getType(),
                                    castOp.getSource().getType());
    if (llvm::isa<BlockArgument>(opOperand.get()))
      return false;
    auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    return castOp && canFoldIntoConsumerOp(castOp);
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  // Assumes that the result list has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
    auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
    if (castOp && canFoldIntoConsumerOp(castOp))
      operand.set(castOp.getOperand());
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<TensorType>(a);
  auto bT = llvm::dyn_cast<TensorType>(b);
  if (!aT || !bT)
    return false;

  if (aT.getElementType() != bT.getElementType())
    return false;

  return succeeded(verifyCompatibleShape(aT, bT));
}
  int64_t rank = one.getRank();
  if (rank != two.getRank())
    return {};

  SmallVector<int64_t, 4> join;
  join.reserve(rank);
  for (int64_t i = 0; i < rank; ++i) {
    if (one.isDynamicDim(i)) {
      join.push_back(two.getDimSize(i));
      continue;
    }
    if (two.isDynamicDim(i)) {
      join.push_back(one.getDimSize(i));
      continue;
    }
    if (one.getDimSize(i) != two.getDimSize(i))
      return {};
    join.push_back(one.getDimSize(i));
  }
  return RankedTensorType::get(join, one.getElementType());
namespace {
/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

    if (!tensorCastOperand)
      return failure();

    auto sourceType =
        llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
    auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
    auto resultType = llvm::cast<TensorType>(tensorCast.getType());

    // We can remove the intermediate cast if joining all three produces the
    // same result as just joining the source and result shapes.
    auto firstJoin =
        joinShapes(joinShapes(sourceType, intermediateType), resultType);
    if (!firstJoin)
      return failure();

    // The new join might contain less information than the first join; if so,
    // dropping the intermediate cast would remove runtime checks.
    auto newJoin = joinShapes(sourceType, resultType);
    if (firstJoin != newJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                        tensorCastOperand.getOperand());
    return success();
  }
};
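// A hedged IR sketch (our addition, not from the original source): the chain
//
//   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
//   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
//
// folds to the single cast
//
//   %2 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x?xf32>
//
// because joining all three shapes (4x4) equals joining just the source and
// result shapes, so no runtime check is lost.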
/// Fold tensor.cast into its tensor.extract_slice producer when the cast
/// only makes sizes more static.
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    auto extractOperand =
        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

    // Cannot fold cast to unranked tensor.
    auto rankedResultType =
        llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
    if (!rankedResultType)
      return failure();

    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
        rankedResultType.getShape() ==
            llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
                .getShape())
      return failure();

    SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
    auto dimMask = computeRankReductionMask(
        extractOperand.getStaticSizes(), extractOperand.getType().getShape());
    size_t dimIndex = 0;
    for (size_t i = 0, e = sizes.size(); i < e; i++) {
      if (dimMask && dimMask->count(i))
        continue;
      int64_t dim = rankedResultType.getShape()[dimIndex++];
      if (ShapedType::isDynamic(dim))
        continue;
      sizes[i] = rewriter.getIndexAttr(dim);
    }

    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, rankedResultType, extractOperand.getSource(),
        extractOperand.getMixedOffsets(), sizes,
        extractOperand.getMixedStrides());
    return success();
  }
};
} // namespace

void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  SmallVector<RankedTensorType> tensorTypes =
      llvm::map_to_vector<4>(inputTypes, llvm::CastTo<RankedTensorType>);
  int64_t concatRank = tensorTypes[0].getRank();

  // The concatenation dim must be in the range [0, rank).
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");

  SmallVector<int64_t> sizes(concatRank);
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : tensorTypes)
      size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : tensorTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}

void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
                     ValueRange inputs) {
  FailureOr<RankedTensorType> resultType =
      inferResultType(dim, inputs.getTypes());
  assert(succeeded(resultType) && "failed to infer concatenation result type");
  build(builder, result, *resultType, dim, inputs);
}
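// Worked example (our addition): concatenating tensor<3x4xf32> and
// tensor<5x4xf32> along dim = 0 infers tensor<8x4xf32>; if either input had a
// dynamic leading size (tensor<?x4xf32>), the saturated addition above would
// produce tensor<?x4xf32> instead.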
LogicalResult ConcatOp::verify() {
  if (getInputs().size() < 1)
    return emitOpError("requires at least one input");

  SmallVector<RankedTensorType> inputTypes;
  for (auto input : getInputs())
    inputTypes.push_back(cast<RankedTensorType>(input.getType()));

  RankedTensorType resultType = getResultType();
  int64_t resultRank = getRank();
  if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
        return type.getRank() != resultRank;
      }))
    return emitOpError("rank of concatenated inputs must match result rank");

  Type resultElementType = resultType.getElementType();
  if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
        return type.getElementType() != resultElementType;
      }))
    return emitOpError("inputs and result element type must match");

  int64_t dim = getDim();
  if (dim >= resultRank)
    return emitOpError("concatenation dim must be less than the tensor rank");

  SmallVector<int64_t> sizes(resultRank);
  for (int64_t i = 0, e = resultRank; i < e; ++i) {
    if (i == dim)
      continue;
    SaturatedInteger size;
    for (auto tensorType : inputTypes) {
      FailureOr<SaturatedInteger> maybeSize =
          size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
      if (failed(maybeSize))
        return emitOpError("static concatenation size mismatch along ")
               << "non-concatenated dimension " << i;
      size = *maybeSize;
    }
    sizes[i] = size.asInteger();
  }
  auto concatSize = SaturatedInteger::wrap(0);
  for (auto tensorType : inputTypes)
    concatSize =
        concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
  sizes[dim] = concatSize.asInteger();
  auto inferredResultType =
      RankedTensorType::get(sizes, inputTypes[0].getElementType());

  for (auto [inferredSize, actualSize] :
       llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
    bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                      ShapedType::isDynamic(actualSize);
    if (!hasDynamic && inferredSize != actualSize)
      return emitOpError("result type ")
             << resultType << " does not match inferred shape "
             << inferredResultType << " static sizes";
  }

  return success();
}
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  SmallVector<SmallVector<OpFoldResult>> inputShapes;
  inputShapes.reserve(numInputs);
  SmallVector<OpFoldResult> concatOffsets;
  concatOffsets.reserve(numInputs);
  SmallVector<OpFoldResult> outputShape;

  AffineExpr addExpr =
      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
  OpFoldResult zero = builder.getIndexAttr(0);
  Location loc = getLoc();
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    SmallVector<OpFoldResult> inputShape =
        tensor::getMixedSizes(builder, input.getLoc(), input);
    if (index == 0) {
      outputShape = inputShape;
      concatOffsets.push_back(zero);
    } else {
      concatOffsets.push_back(outputShape[concatDim]);
      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
          builder, loc, addExpr,
          {outputShape[concatDim], inputShape[concatDim]});
    }
    inputShapes.emplace_back(std::move(inputShape));
  }

  Value replacement = tensor::EmptyOp::create(builder, loc, outputShape,
                                              getType().getElementType());

  int64_t rank = getType().getRank();
  OpFoldResult one = builder.getIndexAttr(1);
  SmallVector<OpFoldResult> strides(rank, one);
  SmallVector<OpFoldResult> offsets(rank, zero);
  for (auto [index, input] : llvm::enumerate(getInputs())) {
    offsets[concatDim] = concatOffsets[index];
    auto insertSlice = tensor::InsertSliceOp::create(
        builder, loc, input, replacement, offsets, inputShapes[index], strides);
    replacement = insertSlice.getResult();
  }
  if (replacement.getType() != getType())
    replacement = tensor::CastOp::create(builder, loc, getType(), replacement);
  return SmallVector<Value>{replacement};
}
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  ValueRange inputs = getInputs();
  int64_t dim = getDim();
  RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

  Value init = inputs[0];
  int64_t rank = getType().getRank();

  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));

  // Prefer static sizes from the result type, then from the inferred type,
  // before falling back to a dim query on the first input.
  for (int64_t i = 0; i < rank; ++i) {
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    } else if (!inferredResultType.isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
          builder, init.getLoc(),
          builder.getIndexAttr(inferredResultType.getDimSize(i)));
    } else {
      reifiedReturnShapes[0][i] =
          tensor::DimOp::create(builder, init.getLoc(), init, i).getResult();
    }
  }

  if (getType().isDynamicDim(dim)) {
    // The size along the concatenation dim is the sum of all input sizes.
    SmallVector<OpFoldResult> sizes = {
        builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
    AffineExpr sum = builder.getAffineDimExpr(0);
    for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
      sum = sum + builder.getAffineDimExpr(idx + 1);
      sizes.push_back(
          builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
    }
    reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
        builder, getLoc(),
        affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
  } else {
    // Static dim: take the static size from the result type.
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
  return success();
}
void ConcatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "concat");
}

OpFoldResult ConcatOp::fold(FoldAdaptor) {
  ValueRange inputs = getInputs();
  if (inputs.size() == 1 && inputs[0].getType() == getResultType())
    return inputs[0];
  return {};
}
namespace {
/// Fold a concat with a single input to a cast.
struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    if (concatOp.getInputs().size() != 1)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getType(),
                                        concatOp.getInputs()[0]);
    return success();
  }
};
/// Propagate static shape information into the operands of a tensor.concat by
/// inserting casts to the inferred (more static) operand types.
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    // Inferred operand shapes are identical in every dimension except the
    // concatenation dimension.
    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
    for (auto [operandIdx, operandType] :
         llvm::enumerate(concatOp->getOperandTypes())) {
      // Compute the inferred type for this operand.
      inferredOperandShape[dim] =
          cast<RankedTensorType>(operandType).getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, inferredResultType.getElementType());

      // Check if the inferred type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use the refined type and cast from the original operand.
        auto castOp =
            CastOp::create(rewriter, concatOp->getLoc(), inferredOperandType,
                           concatOp.getOperand(operandIdx));
        int64_t idx = operandIdx;
        rewriter.modifyOpInPlace(
            concatOp, [&] { concatOp->setOperand(idx, castOp->getResult(0)); });
      }
    }

    return matched;
  }
};
/// Infer a more static result type from the operand types and cast back to
/// the original result type.
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // The result type is already at least as static as the inferred type.
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType())) {
      return failure();
    }

    auto newConcatOp =
        ConcatOp::create(rewriter, concatOp->getLoc(), inferredResultType, dim,
                         concatOp->getOperands());
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        newConcatOp);
    return success();
  }
};
} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
}
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}
std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types is not supported.
  auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
  if (!tensorType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= tensorType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!tensorType.isDynamicDim(index.getInt())) {
    // Static shape dimensions are always folded.
    return IntegerAttr::get(IndexType::get(getContext()),
                            tensorType.getShape()[index.getInt()]);
  }

  Operation *definingOp = getSource().getDefiningOp();

  // Fold dim to the operand of tensor.generate.
  if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
    auto resultType =
        llvm::cast<RankedTensorType>(fromElements.getResult().getType());
    // The case where the type encodes the size of the dimension is handled
    // above.
    assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

    // Find the operand of the fromElements that corresponds to this index.
    auto dynExtents = fromElements.getDynamicExtents().begin();
    for (auto dim : resultType.getShape().take_front(index.getInt()))
      if (ShapedType::isDynamic(dim))
        dynExtents++;

    return Value{*dynExtents};
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
    // Fold only for non-rank-reduced ops. For the rank-reduced version, rely
    // on the `resolve-shaped-type-result-dims` pass.
    if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
        sliceOp.isDynamicSize(unsignedIndex))
      return {sliceOp.getDynamicSize(unsignedIndex)};
  }

  return {};
}
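// Hedged folding example (our addition, not from the original source): for
//
//   %t = ... : tensor<4x?xf32>
//   %c0 = arith.constant 0 : index
//   %d = tensor.dim %t, %c0 : tensor<4x?xf32>
//
// the fold above returns the constant index 4, while a query on dimension 1
// has to look through the defining op (generate, extract_slice) or stay
// unfolded.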
namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
    return success();
  }
};
/// Fold dim of a destination-passing-style op into the dim of the
/// corresponding init operand.
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto source = dimOp.getSource();
    auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
    if (!destOp)
      return failure();

    auto resultIndex = cast<OpResult>(source).getResultNumber();
    auto *initOperand = destOp.getDpsInitOperand(resultIndex);

    rewriter.modifyOpInPlace(
        dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
    return success();
  }
};
/// Fold dim of a tensor reshape operation to an extract into the reshape's
/// shape operand.
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
    if (!reshape)
      return failure();

    // Since tensors are immutable we don't need to worry about where to place
    // the extract call.
    rewriter.setInsertionPointAfter(dim);
    Location loc = dim.getLoc();
    Value extract =
        ExtractOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (extract.getType() != dim.getType())
      extract =
          arith::IndexCastOp::create(rewriter, loc, dim.getType(), extract);
    rewriter.replaceOp(dim, extract);
    return success();
  }
};
} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    Attribute encoding) {
  assert(none_of(staticShape, ShapedType::isDynamic) &&
         "expected only static sizes");
  build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<int64_t> staticShape, Type elementType,
                    ValueRange dynamicSizes, Attribute encoding) {
  auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
  build(builder, result, tensorType, dynamicSizes);
}

void EmptyOp::build(OpBuilder &builder, OperationState &result,
                    ArrayRef<OpFoldResult> sizes, Type elementType,
                    Attribute encoding) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}
LogicalResult EmptyOp::verify() {
  if (getType().getNumDynamicDims() != getDynamicSizes().size())
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
  SmallVector<OpFoldResult> result;
  unsigned ctr = 0;
  OpBuilder b(getContext());
  for (int64_t dim : getType().getShape()) {
    if (ShapedType::isDynamic(dim)) {
      result.push_back(getDynamicSizes()[ctr++]);
    } else {
      result.push_back(b.getIndexAttr(dim));
    }
  }
  return result;
}
namespace {
/// Make the result type of a tensor.empty statically sized along dimensions
/// that were dynamic but whose size operand is a constant.
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
  using OpRewritePattern<EmptyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(EmptyOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        op.getType(), op.getDynamicSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == op.getType())
      return failure();

    auto newOp = EmptyOp::create(rewriter, op.getLoc(), foldedTensorType,
                                 foldedDynamicSizes);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
    auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
    if (!emptyTensorOp || !maybeConstantIndex)
      return failure();
    auto emptyTensorType = emptyTensorOp.getType();
    if (*maybeConstantIndex < 0 ||
        *maybeConstantIndex >= emptyTensorType.getRank() ||
        !emptyTensorType.isDynamicDim(*maybeConstantIndex))
      return failure();
    rewriter.replaceOp(dimOp,
                       emptyTensorOp.getDynamicSize(*maybeConstantIndex));
    return success();
  }
};
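// Hedged IR sketch (our addition, not from the original source): given
//
//   %e = tensor.empty(%sz) : tensor<?x16xf32>
//   %d = tensor.dim %e, %c0 : tensor<?x16xf32>
//
// the pattern above replaces %d directly with %sz, the dynamic size operand
// tied to dimension 0 of the empty tensor.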
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!canFoldIntoProducerOp(castOp))
      return failure();
    auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
    if (!producer)
      return failure();

    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    ArrayRef<int64_t> resultShape = resultType.getShape();
    SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
    SmallVector<OpFoldResult> newMixedSizes;
    newMixedSizes.reserve(currMixedSizes.size());
    assert(resultShape.size() == currMixedSizes.size() &&
           "mismatch in result shape and sizes of empty op");
    for (auto it : llvm::zip(resultShape, currMixedSizes)) {
      int64_t newDim = std::get<0>(it);
      OpFoldResult currDim = std::get<1>(it);
      // Case 1: The empty tensor dim is static. Check that the tensor cast
      // result dim matches.
      if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
        if (ShapedType::isDynamic(newDim) ||
            newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
          // The cast result shape cannot be more dynamic than the empty
          // tensor result shape (enforced by `canFoldIntoProducerOp` above).
          return rewriter.notifyMatchFailure(
              producer, "mismatch in static value of shape of empty tensor "
                        "result and cast result");
        }
        newMixedSizes.push_back(attr);
        continue;
      }

      // Case 2: The tensor cast shape is static, but the empty tensor result
      // shape is dynamic.
      if (ShapedType::isStatic(newDim)) {
        newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
        continue;
      }

      // Case 3: Both shapes are dynamic. Use the dynamic value from the empty
      // tensor op.
      newMixedSizes.push_back(currDim);
    }

    rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                         resultType.getElementType());
    return success();
  }
};
} // namespace

void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
              ReplaceEmptyTensorStaticShapeDims>(context);
}
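// Hedged IR sketch (our addition, not from the original source):
// FoldEmptyTensorWithCastOp turns
//
//   %0 = tensor.empty(%sz) : tensor<?x16xf32>
//   %1 = tensor.cast %0 : tensor<?x16xf32> to tensor<8x16xf32>
//
// into a direct `tensor.empty() : tensor<8x16xf32>`, taking the static size 8
// from the cast result (case 2 in the loop above).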
namespace {
/// Fold tensor.extract(tensor.cast(source)) to tensor.extract(source) when
/// the cast source is ranked.
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
    if (!tensorCast)
      return failure();
    if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
      return failure();
    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extract, tensorCast.getSource(), extract.getIndices());
    return success();
  }
};
/// Rewrite an extract from a collapse_shape with a static source shape into
/// an extract from the source, delinearizing the collapsed indices.
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const final {
    auto collapseOp =
        extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
    if (!collapseOp)
      return failure();
    if (!collapseOp.getSrcType().hasStaticShape())
      return failure();

    auto sourceSizes = collapseOp.getSrcType().getShape();

    SmallVector<Value> indices(extractOp.getIndices().begin(),
                               extractOp.getIndices().end());
    SmallVector<Value> sourceIndices;
    for (auto [index, group] :
         llvm::zip(indices, collapseOp.getReassociationIndices())) {
      assert(!group.empty() && "association indices groups cannot be empty");
      auto groupSize = group.size();

      if (groupSize == 1) {
        sourceIndices.push_back(index);
        continue;
      }

      SmallVector<int64_t> basis =
          llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
      auto delinearize = affine::AffineDelinearizeIndexOp::create(
          rewriter, extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
      llvm::append_range(sourceIndices, delinearize.getResults());
    }
    if (collapseOp.getReassociationIndices().empty()) {
      auto zeroAffineMap = rewriter.getConstantAffineMap(0);
      int64_t srcRank =
          cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
          rewriter, extractOp.getLoc(), zeroAffineMap,
          ArrayRef<OpFoldResult>{});
      for (int64_t i = 0; i < srcRank; i++) {
        sourceIndices.push_back(
            getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
        extractOp, collapseOp.getSrc(), sourceIndices);
    return success();
  }
};
} // namespace
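// Hedged example (our addition, adapted rather than verbatim): extracting
// %val[%c10] from
//
//   %val = tensor.collapse_shape %src [[0, 1]]
//       : tensor<3x4xf64> into tensor<12xf64>
//
// delinearizes the index 10 over the basis (3, 4) into (2, 2), producing
// `tensor.extract %src[%c2, %c2] : tensor<3x4xf64>`.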
void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

LogicalResult ExtractOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
  if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices for extract_element");
  return success();
}
/// If an ExtractOp consumes an InsertOp with the same indices, return the
/// InsertOp's scalar directly.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
  auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();

  auto isSame = [](Value a, Value b) {
    return getAsOpFoldResult(a) == getAsOpFoldResult(b);
  };
  if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
      llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
    return insertOp.getScalar();

  return {};
}
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
  if (Attribute tensor = adaptor.getTensor()) {
    // If this is a splat elements attribute, simply return the value.
    // All of the elements of a splat attribute are the same.
    if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
      return splatTensor.getSplatValue<Attribute>();

    // If this is a dense resource elements attribute, return.
    if (isa<DenseResourceElementsAttr>(tensor))
      return {};
  }

  // Collect the constant indices into the tensor.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indice : adaptor.getIndices()) {
    if (!indice || !llvm::isa<IntegerAttr>(indice))
      return {};
    indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
  }

  // Fold extract(from_elements(...)).
  if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
    auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
    auto rank = tensorType.getRank();
    assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
           "rank mismatch");
    int flatIndex = 0;
    int stride = 1;
    for (int i = rank - 1; i >= 0; --i) {
      flatIndex += indices[i] * stride;
      stride *= tensorType.getDimSize(i);
    }
    // Prevent out-of-bounds accesses. This can happen in invalid code that
    // will never execute.
    if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
        flatIndex < 0)
      return {};
    return fromElementsOp.getElements()[flatIndex];
  }

  // If this is an elements attribute, query the value at the given indices.
  if (Attribute tensor = adaptor.getTensor()) {
    auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
    if (elementsAttr && elementsAttr.isValidIndex(indices))
      return elementsAttr.getValues<Attribute>()[indices];
  }

  if (Value result = foldExtractAfterInsert(*this))
    return result;

  return {};
}

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractFromTensorCast>(context);
}
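// Worked example (our addition) of the row-major flattening in the fold
// above: for a tensor<2x3xi32> built by tensor.from_elements, indices (1, 2)
// flatten to flatIndex = 2 * 1 + 1 * 3 = 5, i.e. the last of the six
// elements (stride runs 1, then 3 as the loop walks dims right-to-left).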
void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

void FromElementsOp::build(OpBuilder &builder, OperationState &result,
                           ValueRange elements) {
  assert(!elements.empty() && "expected at least one element");
  Type resultType = RankedTensorType::get(
      {static_cast<int64_t>(elements.size())}, elements.front().getType());
  build(builder, result, resultType, elements);
}

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
  return {};
}
namespace {

// Pushes an index_cast that occurs before an extract to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer.
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    Location loc = extract.getLoc();
    auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
    if (!indexCast)
      return failure();

    Type elementTy = getElementTypeOrSelf(indexCast.getIn());

    auto newExtract = tensor::ExtractOp::create(
        rewriter, loc, elementTy, indexCast.getIn(), extract.getIndices());

    rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
                                                    newExtract);
    return success();
  }
};

} // namespace

void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<ExtractElementFromIndexCast>(context);
}
void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}

/// Return the inferred result type for a gather where `gatherDims` are the
/// dims along which the gather occurs. Returns the full-rank or rank-reduced
/// variant depending on `rankReduced`.
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (llvm::binary_search(gatherDims, idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
  return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
                          ArrayRef<int64_t> indices, int64_t rank,
                          StringRef gatherOrScatter, StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  if (indices.empty() || indices.back() != numGatherDims)
    return op->emitOpError(gatherOrScatter)
           << "_dims length must match the size of last dimension of indices";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
  return success();
}
LogicalResult GatherOp::verify() {
  int64_t sourceRank = getSourceType().getRank();
  ArrayRef<int64_t> gatherDims = getGatherDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
                                       getIndicesType().getShape(), sourceRank,
                                       "gather", "source")))
    return failure();

  RankedTensorType expectedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
      getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
  if (getResultType() != expectedResultType &&
      getResultType() != expectedRankReducedResultType) {
    return emitOpError("result type mismatch: expected ")
           << expectedResultType << " or its rank-reduced variant "
           << expectedRankReducedResultType << " (got: " << getResultType()
           << ")";
  }

  return success();
}

OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

LogicalResult InsertOp::verify() {
  // Verify the # indices match if we have a ranked type.
  auto destType = llvm::cast<RankedTensorType>(getDest().getType());
  if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
    return emitOpError("incorrect number of indices");
  return success();
}

OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
  Attribute scalar = adaptor.getScalar();
  Attribute dest = adaptor.getDest();
  if (scalar && dest)
    if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
      if (scalar == splatDest.getSplatValue<Attribute>())
        return dest;
  return {};
}
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

LogicalResult GenerateOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  int idx = 0;
  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
    if (getType().isDynamicDim(dim)) {
      reifiedReturnShapes[0][dim] = getOperand(idx++);
    } else {
      reifiedReturnShapes[0][dim] =
          builder.getIndexAttr(getType().getDimSize(dim));
    }
  }
  return success();
}

LogicalResult GenerateOp::verify() {
  // Ensure that the tensor type has as many dynamic dimensions as are
  // specified by the operands.
  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
  if (getNumOperands() != resultType.getNumDynamicDims())
    return emitError("must have as many index operands as dynamic extents "
                     "in the result type");
  return success();
}

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  // Ensure that region arguments span the index space.
  if (!llvm::all_of(getBody().getArgumentTypes(),
                    [](Type ty) { return ty.isIndex(); }))
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  // Ensure that the region yields an element of the right type.
  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

  if (yieldOp.getValue().getType() != resultTy.getElementType())
    return emitOpError(
        "body must be terminated with a `yield` operation of the tensor "
        "element type");

  return success();
}
void GenerateOp::build(
    OpBuilder &b, OperationState &result, Type resultTy,
    ValueRange dynamicExtents,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
  build(b, result, resultTy, dynamicExtents);

  // Build and populate the body.
  OpBuilder::InsertionGuard guard(b);
  Region *bodyRegion = result.regions.front().get();
  auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
  SmallVector<Location, 2> argumentLocs(rank, result.location);
  Block *bodyBlock =
      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
  bodyBuilder(b, result.location, bodyBlock->getArguments());
}
namespace {
/// Canonicalize tensor.generate ops with constant dynamic extents into the
/// equivalent op with those extents folded into the result type, plus a cast
/// to keep the IR well-typed.
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
  using OpRewritePattern<GenerateOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenerateOp generateOp,
                                PatternRewriter &rewriter) const final {
    SmallVector<Value> foldedDynamicSizes;
    RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
        generateOp.getType(), generateOp.getDynamicExtents(),
        foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedTensorType == generateOp.getType())
      return failure();

    auto loc = generateOp.getLoc();
    auto newOp =
        GenerateOp::create(rewriter, loc, foldedTensorType, foldedDynamicSizes);
    rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
                                newOp.getBody().begin());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
                                                generateOp.getType(), newOp);
    return success();
  }
};
/// Fold tensor.extract(tensor.generate) by inlining the generate body with
/// the extract indices substituted for the body arguments, provided the
/// generate op has no side effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
                                PatternRewriter &rewriter) const final {
    auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
    if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
      return failure();

    IRMapping mapping;
    Block *body = &tensorFromElements.getBody().front();
    mapping.map(body->getArguments(), extract.getIndices());
    for (auto &op : body->without_terminator())
      rewriter.clone(op, mapping);

    auto yield = cast<YieldOp>(body->getTerminator());
    rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
    return success();
  }
};
} // namespace

void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
  // Constant fold rank when the rank of the operand is known.
  auto type = getOperand().getType();
  auto shapedType = llvm::dyn_cast<ShapedType>(type);
  if (shapedType && shapedType.hasRank())
    return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
  return IntegerAttr();
}
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

static int64_t getNumElements(ShapedType type) {
  int64_t numElements = 1;
  for (auto dim : type.getShape())
    numElements *= dim;
  return numElements;
}
LogicalResult ReshapeOp::verify() {
  TensorType operandType = llvm::cast<TensorType>(getSource().getType());
  TensorType resultType = llvm::cast<TensorType>(getResult().getType());

  if (operandType.getElementType() != resultType.getElementType())
    return emitOpError("element types of source and destination tensor "
                       "types should be the same");

  int64_t shapeSize =
      llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
  auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
  auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

  if (resultRankedType) {
    if (operandRankedType && resultRankedType.hasStaticShape() &&
        operandRankedType.hasStaticShape()) {
      if (getNumElements(operandRankedType) !=
          getNumElements(resultRankedType))
        return emitOpError("source and destination tensor should have the "
                           "same number of elements");
    }
    if (ShapedType::isDynamic(shapeSize))
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked tensor type");
    if (shapeSize != resultRankedType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's tensor rank");
  }
  return success();
}
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;

  // If the producer of the source is another reshape, use the producer's
  // input instead; this may render the producer dead.
  if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
    getSourceMutable().assign(reshapeOpProducer.getSource());
    return getResult();
  }

  auto source = getSource();
  auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
  auto resultTy = dyn_cast<RankedTensorType>(getType());
  if (!sourceTy || !resultTy || sourceTy != resultTy)
    return {};

  // If the source and result are both rank <= 1 tensors with the same type,
  // the reshape has no effect, even when dynamically shaped.
  if (sourceTy.getRank() <= 1)
    return source;

  if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
    auto elements = fromElements.getElements();
    bool dynamicNoop =
        sourceTy.getRank() == static_cast<int64_t>(elements.size());
    for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
      auto element = elements[id];

      if (auto cst = getConstantIntValue(element)) {
        dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
        continue;
      }

      if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
        dynamicNoop &= dimOp.getSource() == source;

        auto cst = getConstantIntValue(dimOp.getIndex());
        dynamicNoop &=
            cst.has_value() && cst.value() == static_cast<int64_t>(id);
        continue;
      }

      dynamicNoop = false;
      break;
    }

    if (dynamicNoop)
      return source;
  }

  return {};
}
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
                                RankedTensorType expandedType,
                                ArrayRef<ReassociationIndices> reassociation,
                                ArrayRef<OpFoldResult> inputShape) {
  std::optional<SmallVector<OpFoldResult>> outputShape =
      inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
                                  inputShape);
  if (!outputShape)
    return failure();
  return *outputShape;
}

SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
  return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation,
                          ArrayRef<OpFoldResult> outputShape) {
  auto [staticOutputShape, dynamicOutputShape] =
      decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
  build(builder, result, cast<RankedTensorType>(resultType), src,
        getReassociationIndicesAttribute(builder, reassociation),
        dynamicOutputShape, staticOutputShape);
}

void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
                          Type resultType, Value src,
                          ArrayRef<ReassociationIndices> reassociation) {
  SmallVector<OpFoldResult> inputShape =
      getMixedSizes(builder, result.location, src);
  auto tensorResultTy = cast<RankedTensorType>(resultType);
  FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
      builder, result.location, tensorResultTy, reassociation, inputShape);
  SmallVector<OpFoldResult> outputShapeOrEmpty;
  if (succeeded(outputShape)) {
    outputShapeOrEmpty = *outputShape;
  }
  build(builder, result, tensorResultTy, src, reassociation,
        outputShapeOrEmpty);
}

SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}

SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
  return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  return convertReassociationIndicesToExprs(getContext(),
                                            getReassociationIndices());
}
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, ArrayRef<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  // Use the fact that reassociation is valid to simplify the logic: only use
  // each map's rank.
  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

  return RankedTensorType::get(newShape, type.getElementType());
}

void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
                            ArrayRef<ReassociationIndices> reassociation,
                            ArrayRef<NamedAttribute> attrs) {
  auto srcType = llvm::cast<RankedTensorType>(src.getType());
  RankedTensorType collapsedType = inferCollapsedType(srcType, reassociation);
  RankedTensorType resultType =
      RankedTensorType::get(collapsedType.getShape(), srcType.getElementType(),
                            srcType.getEncoding());
  result.addAttribute(getReassociationAttrStrName(),
                      getReassociationIndicesAttribute(b, reassociation));
  build(b, result, resultType, src, attrs);
}
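// Worked example (our addition): collapsing tensor<2x3x4xf32> with
// reassociation [[0, 1], [2]] multiplies the static sizes within each group,
// giving tensor<6x4xf32>; any dynamic size in a group makes that whole group
// dynamic ('?') per the kDynamic check above.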
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {
  if (failed(
          verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
    return failure();

  // The number of elements must be preserved when both shapes are static.
  if (expandedType.hasStaticShape() && collapsedType.hasStaticShape()) {
    int64_t expandedNumElements = expandedType.getNumElements();
    int64_t collapsedNumElements = collapsedType.getNumElements();
    if (expandedNumElements != collapsedNumElements) {
      return op.emitOpError("number of elements must be preserved: ")
             << expandedNumElements << " != " << collapsedNumElements;
    }
  }

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypesWithoutEncoding(expectedType, collapsedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
  return success();
}
LogicalResult ExpandShapeOp::verify() {
  RankedTensorType srcType = getSrc().getType();
  RankedTensorType resultType = getResult().getType();

  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape dims to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  return verifyTensorReshapeOp(*this, resultType, srcType);
}

LogicalResult CollapseShapeOp::verify() {
  CollapseShapeOp op = *this;
  if (llvm::any_of(op.getReassociationIndices(),
                   [](auto reInd) { return reInd.empty(); }))
    return op.emitOpError("reassociation indices must not be empty");

  RankedTensorType srcType = op.getSrc().getType();
  RankedTensorType resultType = op.getResult().getType();
  return verifyTensorReshapeOp(op, srcType, resultType);
}
namespace {
/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
      return failure();
    if (!attr || !attr.isSplat())
      return failure();
    DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
    return success();
  }
};
// Fold TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
    if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
    return success();
  }
};
/// Reshape of a FromElements can be replaced with a FromElements of the
/// result type.
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
    if (!fromElements)
      return failure();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
    return success();
  }
};
// Fold CastOp into CollapseShapeOp when adding static information.
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        llvm::cast<RankedTensorType>(castOp.getSource().getType());
    RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = CollapseShapeOp::create(rewriter, collapseShapeOp.getLoc(),
                                           newResultType, castOp.getSource(),
                                           collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};
/// Convert an expand_shape of a cast into a static expand_shape when the
/// cast source carries more static shape information.
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
    SmallVector<ReassociationIndices, 4> reassoc =
        expandOp.getReassociationIndices();

    SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
    SmallVector<Value> dynamicOutputShape;
    auto outputIt = expandOp.getOutputShape().begin();

    for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
      for (uint64_t outDim : innerReassoc) {
        if (ShapedType::isStatic(newOutputShape[outDim]))
          continue;

        // If the cast's src type is dynamic, don't infer any of the
        // corresponding expanded dimensions. `tensor.expand_shape` requires
        // at least one of the expanded dimensions to be dynamic if the input
        // is dynamic.
        Value val = *outputIt;
        ++outputIt;
        if (ShapedType::isDynamic(castSrcShape[inputDim])) {
          dynamicOutputShape.push_back(val);
          continue;
        }

        APInt cst;
        if (matchPattern(val, m_ConstantInt(&cst))) {
          newOutputShape[outDim] = cst.getSExtValue();
        } else {
          dynamicOutputShape.push_back(val);
        }
      }
    }

    // Couldn't match any values, nothing to change.
    if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
      return failure();

    // Calculate the input shape from the output.
    SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
    for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
      for (auto outDim : reassoc[inDim]) {
        auto ofr = newOutputShape[outDim];
        if (ShapedType::isDynamic(ofr)) {
          newInputShape[inDim] = ShapedType::kDynamic;
          break;
        }
        newInputShape[inDim] *= ofr;
      }
    }

    SmallVector<OpFoldResult> outputOfr =
        getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
    auto inputType = RankedTensorType::get(
        newInputShape, expandOp.getSrcType().getElementType());
    auto outputType = RankedTensorType::get(
        newOutputShape, expandOp.getSrcType().getElementType());
    auto inputCast = CastOp::create(rewriter, expandOp.getLoc(), inputType,
                                    expandOp.getSrc());
    auto newExpand = ExpandShapeOp::create(
        rewriter, expandOp.getLoc(), outputType, inputCast.getResult(),
        expandOp.getReassociationIndices(), outputOfr);
    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
                                        newExpand.getResult());
    return success();
  }
};
} // namespace
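// Hedged IR sketch (our addition, not from the original source): with a
// static cast source and a constant output size, e.g.
//
//   %c8 = arith.constant 8 : index
//   %c = tensor.cast %s : tensor<32xf32> to tensor<?xf32>
//   %e = tensor.expand_shape %c [[0, 1]] output_shape [%c8, 4]
//       : tensor<?xf32> into tensor<?x4xf32>
//
// the pattern above recomputes the static output shape [8, 4], rebuilds the
// expand on a cast to tensor<32xf32>, and casts the result back to the
// original tensor<?x4xf32> type.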
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
      ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
      ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
      FoldReshapeWithSplat<ExpandShapeOp>,
      FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<
      ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
      ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
                                tensor::DimOp, RankedTensorType>,
      FoldReshapeWithConstant<CollapseShapeOp>,
      FoldReshapeWithSplat<CollapseShapeOp>,
      FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
      context);
}

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

/// An extract_slice result type can be inferred, when it is not rank-reduced,
/// from the source type and the static representation of sizes.
RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<int64_t> staticSizes) {
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}

RankedTensorType
ExtractSliceOp::inferResultType(RankedTensorType sourceTensorType,
                                ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");
  return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
                               sourceTensorType.getEncoding());
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> sizes) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, sizes));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best-effort rank reduction: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticSizes);
}
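// Worked example (our addition): with desiredResultRank = 2, a source of type
// tensor<8x16x4xf32>, and sizes [1, 16, 4], the non-reduced inference yields
// tensor<1x16x4xf32>; rankDiff = 1, so one unit dimension is projected out
// and the canonical rank-reduced result type is tensor<16x4xf32>.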
/// Build an ExtractSliceOp with mixed static and dynamic entries and a custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
  // Structuring the implementation this way avoids duplication between
  // builders.
  if (!resultType) {
    resultType = llvm::cast<RankedTensorType>(
        ExtractSliceOp::inferResultType(sourceRankedTensorType, staticSizes));
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
/// a Range vector.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ArrayRef<Range> ranges,
                           ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}

/// Build an ExtractSliceOp with dynamic entries and a custom result type. If
/// the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
      offsets, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> sizeValues =
      llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
      strides, [](Value v) -> OpFoldResult { return v; });
  build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}

/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                           ValueRange offsets, ValueRange sizes,
                           ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}
/// Build an ExtractSliceOp with mixed static sizes, zero offsets, and unit
/// strides.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<NamedAttribute> attrs) {
  Attribute zeroIdxAttr = b.getIndexAttr(0);
  Attribute oneIdxAttr = b.getIndexAttr(1);
  SmallVector<OpFoldResult> readStrides(sizes.size(), oneIdxAttr);
  SmallVector<OpFoldResult> readOffsets(sizes.size(), zeroIdxAttr);
  build(b, result, resultType, source, readOffsets, sizes, readStrides, attrs);
}

/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
  RankedTensorType sourceType = getSourceType();

  // Verify result type against inferred type.
  RankedTensorType expectedType =
      ExtractSliceOp::inferResultType(sourceType, getMixedSizes());
  SliceVerificationResult result = isRankReducedType(expectedType, getType());
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, and strides are in bounds of the source.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
  return ::getDroppedDims(getType().getShape(), getMixedSizes());
}

FailureOr<Value>
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
                                   ArrayRef<int64_t> desiredShape) {
  auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
  assert(sourceTensorType && "not a ranked tensor type");
  auto sourceShape = sourceTensorType.getShape();
  if (sourceShape.equals(desiredShape))
    return value;
  auto maybeRankReductionMask =
      mlir::computeRankReductionMask(sourceShape, desiredShape);
  if (!maybeRankReductionMask)
    return failure();
  return createCanonicalRankReducingExtractSliceOp(
      b, loc, value,
      RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
}

LogicalResult ExtractSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1);
  reifiedReturnShapes[0].reserve(getType().getRank());
  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
  llvm::SmallBitVector droppedDims = getDroppedDims();
  for (const auto &size : enumerate(mixedSizes)) {
    if (droppedDims.test(size.index()))
      continue;
    reifiedReturnShapes[0].push_back(size.value());
  }
  return success();
}
namespace {
/// Fold tensor casts into extract_slice operations.
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let the constant folder kick in.
    if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!canFoldIntoConsumerOp(castOp))
      return failure();

    // The pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
        cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    if (!sliceResult.isValid)
      return failure();

    // Create the folded extract.
    Location loc = sliceOp.getLoc();
    Value newResult = ExtractSliceOp::create(
        rewriter, loc, sliceOp.getType(), castOp.getSource(),
        sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(),
        sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
        sliceOp.getStaticStrides());
    rewriter.replaceOp(sliceOp, newResult);
    return success();
  }
};
/// Slice elements from `values` into `outValues`. `counts` represents the
/// numbers of elements to stride in the original values for each dimension.
/// The output values can be used to construct a DenseElementsAttr.
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
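// Hedged walk-through (our addition): for a 2-D source of shape [4, 4],
// counts = [4, 1]; slicing offsets = [1, 0], sizes = [2, 2], strides = [2, 2]
// first recurses on rows 1 and 3 (offset * counts.front() locates each row's
// start), then the 1-D base case copies elements 0 and 2 of each row.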
/// Fold arith.constant and tensor.extract_slice into arith.constant. The
/// folded operation might introduce more constant data; users can control the
/// heuristics via the control function.
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {
public:
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  LogicalResult matchAndRewrite(ExtractSliceOp op,
                                PatternRewriter &rewriter) const override {
    DenseElementsAttr attr;
    if (!matchPattern(op.getSource(), m_Constant(&attr)))
      return failure();

    // A constant splat is handled by fold().
    if (attr.isSplat())
      return failure();

    // Dynamic result shape is not supported.
    auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
    auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // Customized control over the folding.
    if (!controlFn(op))
      return failure();

    int64_t count = sourceType.getNumElements();
    if (count == 0)
      return failure();

    // Check if there are any dynamic parts, which are not supported.
    auto offsets = op.getStaticOffsets();
    if (llvm::is_contained(offsets, ShapedType::kDynamic))
      return failure();
    auto sizes = op.getStaticSizes();
    if (llvm::is_contained(sizes, ShapedType::kDynamic))
      return failure();
    auto strides = op.getStaticStrides();
    if (llvm::is_contained(strides, ShapedType::kDynamic))
      return failure();

    // Compute the stride for each dimension.
    SmallVector<int64_t> counts;
    ArrayRef<int64_t> shape = sourceType.getShape();
    counts.reserve(shape.size());
    for (int64_t v : shape) {
      count = count / v;
      counts.push_back(count);
    }

    // New attribute constructed by the sliced values.
    DenseElementsAttr newAttr;

    if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
      SmallVector<APInt> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
      SmallVector<APFloat> outValues;
      outValues.reserve(sourceType.getNumElements());
      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
          elems.begin(), counts, offsets, sizes, strides, &outValues);
      newAttr = DenseElementsAttr::get(resultType, outValues);
    }

    if (newAttr) {
      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
      return success();
    }

    return failure();
  }

private:
  /// Controls whether the fold happens; users can impose their heuristics
  /// here.
  ControlConstantExtractSliceFusionFn controlFn;
};

} // namespace

void mlir::tensor::populateFoldConstantExtractSlicePatterns(
    RewritePatternSet &patterns,
    const ControlConstantExtractSliceFusionFn &controlFn) {
  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
}
namespace {
/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
  RankedTensorType operator()(ExtractSliceOp op,
                              ArrayRef<OpFoldResult> mixedOffsets,
                              ArrayRef<OpFoldResult> mixedSizes,
                              ArrayRef<OpFoldResult> mixedStrides) {
    // Infer the result type without taking rank reduction into account.
    RankedTensorType nonReducedType =
        ExtractSliceOp::inferResultType(op.getSourceType(), mixedSizes);

    // If the op is not rank-reducing, there is nothing else to do.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the dropped dims into account.
    SmallVector<int64_t> targetShape;
    for (auto i : llvm::seq<int64_t>(mixedSizes.size()))
      if (!droppedDims.test(i))
        targetShape.push_back(nonReducedType.getDimSize(i));

    return RankedTensorType::get(targetShape, nonReducedType.getElementType(),
                                 nonReducedType.getEncoding());
  }
};

struct SliceCanonicalizer {
  void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
                  ExtractSliceOp newOp) {
    Value replacement = newOp.getResult();
    if (replacement.getType() != op.getType())
      replacement = tensor::CastOp::create(rewriter, op.getLoc(), op.getType(),
                                           replacement);
    rewriter.replaceOp(op, replacement);
  }
};
} // namespace

void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  results.add<
      OpWithOffsetSizesAndStridesConstantArgumentFolder<
          ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
      ExtractSliceOpCastFolder>(context);
}
/// Return success if `op` is a no-op slice of `shapedType`: zero offsets,
/// sizes that match the type's shape, and unit strides.
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
                                           ShapedType shapedType) {
  for (OpFoldResult ofr : op.getMixedOffsets())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
      return failure();
  // Rank-reducing no-ops only need to inspect the leading dimensions;
  // llvm::zip does not require both ranges to have the same length.
  auto shape = shapedType.getShape();
  for (auto it : llvm::zip(op.getMixedSizes(), shape))
    if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
      return failure();
  for (OpFoldResult ofr : op.getMixedStrides())
    if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
      return failure();
  return success();
}
/// If an ExtractSliceOp consumes an InsertSliceOp with the same slice, return
/// the InsertSliceOp's source directly.
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
  auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
      insertOp.isSameAs(extractOp, isSame))
    return insertOp.getSource();

  return {};
}
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  if (Value slice = foldExtractAfterInsertSlice(*this))
    return slice;

  return OpFoldResult();
}

Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
    OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                                offsets, sizes, strides);
}
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}

// Build an InsertSliceOp with mixed static and dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes,
                          ArrayRef<OpFoldResult> strides,
                          ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}

/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
/// Range vector.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<Range> ranges,
                          ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build an InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ValueRange offsets, ValueRange sizes,
                          ValueRange strides, ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
      offsets, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> sizeValues =
      llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
      strides, [](Value v) -> OpFoldResult { return v; });
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp.
static SliceVerificationResult verifyInsertSliceOp(
    RankedTensorType srcType, RankedTensorType dstType,
    ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
    ArrayRef<int64_t> staticStrides,
    RankedTensorType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice: use the same type
  // inference.
  RankedTensorType expected =
      ExtractSliceOp::inferResultType(dstType, staticSizes);
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
}
/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, and strides are in bounds with respect to the
  // destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
/// If we have two consecutive InsertSliceOps writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to skip over the first
/// one, which then becomes dead.
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
  auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!prevInsertOp ||
      prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
      !prevInsertOp.isSameAs(insertOp, isSame))
    return failure();

  insertOp.getDestMutable().assign(prevInsertOp.getDest());
  return success();
}
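// Illustrative example (not from the original source): the second insert
// fully overwrites the first (same offsets/sizes/strides, same source type),
// so its destination is redirected to %dest and the first insert dies:
//
//   %0 = tensor.insert_slice %s0 into %dest[0, 0] [4, 4] [1, 1] : ...
//   %1 = tensor.insert_slice %s1 into %0[0, 0] [4, 4] [1, 1] : ...
// -->
//   %1 = tensor.insert_slice %s1 into %dest[0, 0] [4, 4] [1, 1] : ...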
/// Folds round-trip extract/insert slice op pairs.
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
  auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

  auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
  if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
      !extractOp.isSameAs(insertOp, isSame))
    return nullptr;

  return extractOp.getSource();
}
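// Illustrative example (not from the original source): extracting a slice
// from %t and re-inserting it at the same position is a no-op, so the pair
// folds to %t:
//
//   %0 = tensor.extract_slice %t[0, 1] [4, 4] [1, 1] : ...
//   %1 = tensor.insert_slice %0 into %t[0, 1] [4, 4] [1, 1] : ...
// -->
//   %t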
OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
      getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();
  if (succeeded(foldInsertAfterInsertSlice(*this)))
    return getResult();
  if (auto result = foldInsertAfterExtractSlice(*this))
    return result;

  return OpFoldResult();
}
LogicalResult InsertSliceOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
  return success();
}
namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {
public:
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());

    // No constant operands were folded, just return.
    if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
        failed(foldDynamicOffsetSizeList(mixedSizes)) &&
        failed(foldDynamicStrideList(mixedStrides)))
      return failure();

    // Pattern does not apply if the produced op would not verify.
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
                            mixedOffsets, mixedSizes, mixedStrides);
    if (!sliceResult.isValid)
      return failure();

    // Create the new op in canonical form.
    auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
        insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
        mixedSizes);
    Value toInsert = insertSliceOp.getSource();
    if (sourceType != insertSliceOp.getSourceType()) {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the parallel combining op
      // in the parallel case.
      if (isa<InParallelOpInterface>(insertSliceOp->getParentOp()))
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        sourceType, toInsert);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
        mixedSizes, mixedStrides);
    return success();
  }
};
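// Illustrative example (not from the original source): after constant
// folding, dynamic offset operands that are bound to constants become static
// attributes of the op:
//
//   %c0 = arith.constant 0 : index
//   %0 = tensor.insert_slice %s into %d[%c0, %c0] [4, 4] [1, 1] : ...
// -->
//   %0 = tensor.insert_slice %s into %d[0, 0] [4, 4] [1, 1] : ...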
/// Fold tensor casts with insert_slice operations. If the source or
/// destination tensor is a tensor cast that removes static type information,
/// the cast is folded into the insert_slice operation.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
      auto castOp = v.getDefiningOp<tensor::CastOp>();
      if (!castOp || !canFoldIntoConsumerOp(castOp))
        return std::nullopt;
      return castOp.getSource();
    };
    std::optional<Value> sourceCastSource =
        getSourceOfCastOp(insertSliceOp.getSource());
    std::optional<Value> destCastSource =
        getSourceOfCastOp(insertSliceOp.getDest());
    if (!sourceCastSource && !destCastSource)
      return failure();

    auto src =
        (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
    auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
    auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
    auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
    if (!srcType || !dstType)
      return failure();

    // The tensor.cast source could have additional static information not
    // seen in the insert slice op static sizes, so we ignore dynamic dims
    // when computing the rank reduction mask.
    SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
    auto rankReductionMask = computeRankReductionMask(
        staticSizes, srcType.getShape(), /*matchDynamic=*/true);
    if (!rankReductionMask.has_value())
      return failure();

    // Replace dimensions in the insert slice op with the corresponding static
    // dims from the cast source type. If a size is static here but dynamic in
    // the cast source, leave it alone; the pattern will then fail in
    // `verifyInsertSliceOp` below.
    SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
    int64_t rankReducedIdx = 0;
    for (auto [idx, size] : enumerate(staticSizes)) {
      if (!rankReductionMask.value().contains(idx) &&
          !srcType.isDynamicDim(rankReducedIdx)) {
        mixedSizes[idx] = getAsIndexOpFoldResult(
            rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
        size = srcType.getDimSize(rankReducedIdx++);
      }
    }

    // Pattern does not apply if the produced op would not verify.
    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                            staticSizes, insertSliceOp.getStaticStrides()) !=
        SliceVerificationResult::Success)
      return failure();
    SliceBoundsVerificationResult sliceResult =
        verifyInBoundsSlice(dstType.getShape(),
                            insertSliceOp.getMixedOffsets(), mixedSizes,
                            insertSliceOp.getMixedStrides());
    if (!sliceResult.isValid)
      return failure();

    Operation *replacement =
        InsertOpTy::create(rewriter, insertSliceOp.getLoc(), src, dst,
                           insertSliceOp.getMixedOffsets(), mixedSizes,
                           insertSliceOp.getMixedStrides());

    // In the parallel case there is no result and so nothing to cast.
    bool isParallelInsert =
        std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
    if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
      replacement = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                           insertSliceOp.getDestType(),
                                           replacement->getResult(0));
    }
    rewriter.replaceOp(insertSliceOp, replacement->getResults());
    return success();
  }
};
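// Illustrative example (not from the original source): a cast that erases
// static type information is absorbed by the insert, and a cast back to the
// original destination type is emitted for the (non-parallel) result:
//
//   %0 = tensor.cast %d : tensor<8x8xf32> to tensor<?x?xf32>
//   %1 = tensor.insert_slice %s into %0[0, 0] [4, 4] [1, 1] : ...
// -->
//   %1 = tensor.insert_slice %s into %d[0, 0] [4, 4] [1, 1] : ...
//   %2 = tensor.cast %1 : tensor<8x8xf32> to tensor<?x?xf32>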
/// If additional static type information can be deduced from an
/// insert_slice's size operands, insert an explicit cast of the op's source
/// operand to refine its type.
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType srcType = insertSliceOp.getSourceType();
    if (srcType.getRank() != insertSliceOp.getDestType().getRank())
      return failure();
    SmallVector<int64_t> newSrcShape(srcType.getShape());
    for (int64_t i = 0; i < srcType.getRank(); ++i) {
      if (std::optional<int64_t> constInt =
              getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
        // Bail on invalid IR.
        if (*constInt < 0)
          return failure();
        newSrcShape[i] = *constInt;
      }
    }
    if (!hasValidSizesOffsets(newSrcShape))
      return failure();

    RankedTensorType newSrcType = RankedTensorType::get(
        newSrcShape, srcType.getElementType(), srcType.getEncoding());
    if (srcType == newSrcType ||
        !preservesStaticInformation(srcType, newSrcType) ||
        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
      return failure();

    // newSrcType is:
    //   1) Different from srcType.
    //   2) "More static" than srcType.
    //   3) Cast-compatible with srcType.
    // Insert the cast.
    OpBuilder::InsertionGuard g(rewriter);
    // The only difference between InsertSliceOp and ParallelInsertSliceOp is
    // that the insertion point is just before the parallel combining op in
    // the parallel case.
    if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                        newSrcType, insertSliceOp.getSource());
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, cast, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace
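// Illustrative example (not from the original source): a static size operand
// refines the source type, so an explicit cast is inserted:
//
//   %0 = tensor.insert_slice %s into %d[0] [4] [1]
//       : tensor<?xf32> into tensor<8xf32>
// -->
//   %1 = tensor.cast %s : tensor<?xf32> to tensor<4xf32>
//   %0 = tensor.insert_slice %1 into %d[0] [4] [1]
//       : tensor<4xf32> into tensor<8xf32>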
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
              InsertSliceOpCastFolder<InsertSliceOp>,
              InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
                                                             Location loc,
                                                             Value tensor,
                                                             Value dest) {
  auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
  unsigned rank = rankedTensorType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                               sizes, strides);
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//

void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "padded");
}
LogicalResult PadOp::verify() {
  auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
  auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
  auto expectedType =
      PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
  if (!expectedType) {
    return emitError("failed to infer expectedType from sourceType ")
           << sourceType << ", specified resultType is " << resultType;
  }
  if (resultType.getRank() != expectedType.getRank()) {
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }
  for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
    if (resultType.getDimSize(i) == expectedType.getDimSize(i))
      continue;
    if (expectedType.isDynamicDim(i))
      continue;
    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
  }

  return success();
}
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank =
      llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  Block &block = region.front();
  if (block.getNumArguments() != rank)
    return emitError("expected the block to have ") << rank << " arguments";

  // Note: the block arguments must all be of index type.
  for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
    if (!en.value().isIndex())
      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
  }

  // Ensure that the region yields an element of the right type.
  auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
  if (yieldOp.getValue().getType() !=
      llvm::cast<ShapedType>(getType()).getElementType())
    return emitOpError("expected yield type to match shape element type");

  return success();
}
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!resultShape.empty() && resultShape.size() != rank)
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }

  return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
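// Illustrative example (not from the original source): padding a
// tensor<4x?xf32> with staticLow = [1, 0] and staticHigh = [2, 3] infers
// dim 0 as 4 + 1 + 2 = 7 and keeps dim 1 dynamic, i.e. tensor<7x?xf32>.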
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<int64_t> staticLow,
                  ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  if (!resultType)
    resultType = inferResultType(sourceType, staticLow, staticHigh);
  result.addAttributes(attrs);
  build(b, result, resultType, source, low, high,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ValueRange low, ValueRange high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  unsigned rank = sourceType.getRank();
  SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
  build(b, result, resultType, source, staticVector, staticVector, low, high,
        nofold, attrs);
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, bool nofold,
                  ArrayRef<NamedAttribute> attrs) {
  auto sourceType = llvm::cast<RankedTensorType>(source.getType());
  SmallVector<Value, 4> dynamicLow, dynamicHigh;
  SmallVector<int64_t, 4> staticLow, staticHigh;
  // staticLow and staticHigh carry the full information of the padding
  // config; dynamicLow and dynamicHigh only grow for the entries that are
  // not constant.
  dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
  dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
  if (!resultType) {
    resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
  }
  assert(llvm::isa<RankedTensorType>(resultType));
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicLow, dynamicHigh,
        b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
        nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
                  Value source, ArrayRef<OpFoldResult> low,
                  ArrayRef<OpFoldResult> high, Value constantPadValue,
                  bool nofold, ArrayRef<NamedAttribute> attrs) {
  build(b, result, resultType, source, low, high, nofold, attrs);

  // Add a region and a block to yield the pad value.
  Region *region = result.regions[0].get();
  int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
  SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
  SmallVector<Location> blockArgLocs(sourceRank, result.location);

  // `createBlock` changes the insertion point within the block. Create a
  // guard to reset the insertion point of the builder after it is destroyed.
  OpBuilder::InsertionGuard guard(b);
  b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
  tensor::YieldOp::create(b, result.location, constantPadValue);
}
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (!isZeroInteger(en.value()))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
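// Illustrative example (not from the original source): for low = [0, 1] and
// high = [0, 0], only dim 1 is marked as padded. A dynamic padding width is
// conservatively treated as padded because it is not a constant zero.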
namespace {
// Folds tensor.pad when all padding widths are statically zero, unless the
// nofold attribute requests otherwise.
struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
      return failure();
    if (padTensorOp.getNofold())
      return failure();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        padTensorOp, padTensorOp.getResult().getType(),
        padTensorOp.getSource());
    return success();
  }
};
// Fold a tensor.cast on the source into the pad op when it adds static
// information.
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
    if (!castOp || !canFoldIntoConsumerOp(castOp))
      return failure();

    auto newResultType = PadOp::inferResultType(
        llvm::cast<RankedTensorType>(castOp.getSource().getType()),
        padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
        padTensorOp.getResultType().getShape());

    if (newResultType == padTensorOp.getResultType()) {
      rewriter.modifyOpInPlace(padTensorOp, [&]() {
        padTensorOp.getSourceMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = PadOp::create(
          rewriter, padTensorOp->getLoc(), newResultType,
          padTensorOp.getSource(), padTensorOp.getStaticLow(),
          padTensorOp.getStaticHigh(), padTensorOp.getLow(),
          padTensorOp.getHigh(), padTensorOp.getNofold(),
          getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
      IRMapping mapper;
      padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          padTensorOp, padTensorOp.getResultType(), newOp);
    }
    return success();
  }
};
// Fold a tensor.cast on the result back into the pad op if the cast adds
// static information.
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    if (!padTensorOp.getResult().hasOneUse())
      return failure();
    auto tensorCastOp =
        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
    if (!tensorCastOp)
      return failure();
    if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                            tensorCastOp.getDest().getType()))
      return failure();

    auto replacementOp = PadOp::create(
        rewriter, padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
        padTensorOp.getSource(), padTensorOp.getStaticLow(),
        padTensorOp.getStaticHigh(), padTensorOp.getLow(),
        padTensorOp.getHigh(), padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
    replacementOp.getRegion().takeBody(padTensorOp.getRegion());

    rewriter.replaceOp(padTensorOp, replacementOp.getResult());
    rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
    return success();
  }
};
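// Illustrative example (not from the original source) for the
// FoldOrthogonalPaddings pattern below: two pad(extract_slice) pairs that
// pad disjoint dimensions are merged into a single extract_slice followed by
// a single pad, e.g.:
//
//   %0 = tensor.extract_slice ...      // slices rows only
//   %1 = tensor.pad %0 high[0, 4] ...  // pads columns only
//   %2 = tensor.extract_slice %1 ...   // slices columns only
//   %3 = tensor.pad %2 high[3, 0] ...  // pads rows only
// -->
//   %0 = tensor.extract_slice ...
//   %1 = tensor.pad %0 high[3, 4] ...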
/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
/// different dimensions into a single extract_slice / pad pair. The pattern
/// applies if the chain is not rank-reducing, all strides are one, all
/// padding is high-padding with one common constant padding value, and no
/// dimension is padded by both pad ops.
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!innerSliceOp)
      return failure();
    auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
    if (!outerPadOp || outerPadOp.getNofold())
      return failure();
    auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
    if (!outerSliceOp)
      return failure();

    // 1) Fail if the chain is rank-reducing.
    int64_t rank = padOp.getSourceType().getRank();
    if (outerSliceOp.getSourceType().getRank() != rank) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold rank-reducing chain");
    }

    // 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
    if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold non-unit stride ExtractSliceOps");
    }

    // 3) Fail if the tensor::PadOps have non-zero low padding.
    if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
      return rewriter.notifyMatchFailure(padOp,
                                         "cannot fold PadOps with low padding");
    }

    // 4) Fail if the tensor::PadOps padding values do not match.
    Attribute innerAttr, outerAttr;
    Value innerValue = padOp.getConstantPaddingValue();
    Value outerValue = outerPadOp.getConstantPaddingValue();
    if (!innerValue || !outerValue ||
        !matchPattern(innerValue, m_Constant(&innerAttr)) ||
        !matchPattern(outerValue, m_Constant(&outerAttr)) ||
        innerAttr != outerAttr) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with different padding values");
    }

    // 5) Fail if a dimension is padded by both tensor::PadOps.
    llvm::SmallBitVector innerDims = padOp.getPaddedDims();
    llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
    if (innerDims.anyCommon(outerDims)) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot fold PadOps with common padding dimensions");
    }

    // 6) Combine the offsets of the two tensor::ExtractSliceOps: for every
    // dimension, find the zero-offset, zero-padding pair and take the offset
    // of the other pair.
    SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newOffsets)) {
      OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
      OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
      if (!innerDims.test(en.index()) && isZeroInteger(innerOffset)) {
        en.value() = outerOffset;
        continue;
      }
      if (!outerDims.test(en.index()) && isZeroInteger(outerOffset)) {
        en.value() = innerOffset;
        continue;
      }
      return rewriter.notifyMatchFailure(
          padOp, "cannot find zero-offset and zero-padding pair");
    }

    // 7) Combine the sizes of the two tensor::ExtractSliceOps: take the size
    // of the outer slice for the dimensions it slices.
    SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
    for (auto en : enumerate(newSizes)) {
      if (!outerDims.test(en.index()))
        continue;
      OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
      int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
      assert(ShapedType::isStatic(sourceSize) &&
             "expected padded dimension to have a static size");
      if (getConstantIntValue(sliceSize) != sourceSize) {
        return rewriter.notifyMatchFailure(
            padOp, "cannot fold since the inner ExtractSliceOp size does not "
                   "match the size of the outer padding");
      }
      en.value() = outerSliceOp.getMixedSizes()[en.index()];
    }

    // Combine the high paddings of the two tensor::PadOps.
    SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
    for (auto en : enumerate(newHighPad)) {
      if (innerDims.test(en.index()))
        newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
      if (outerDims.test(en.index()))
        newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
    }

    // Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
    // both paddings in one step.
    auto newSliceOp = ExtractSliceOp::create(
        rewriter, padOp.getLoc(), outerSliceOp.getSource(), newOffsets,
        newSizes, innerSliceOp.getMixedStrides());
    auto newPadOp = PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        newSliceOp.getResult(), padOp.getMixedLowPad(), newHighPad,
        padOp.getNofold(),
        getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
/// Folds constant low/high padding operands into the op's static attributes
/// and computes a more static result type where possible.
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
  using OpRewritePattern<PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PadOp padTensorOp,
                                PatternRewriter &rewriter) const override {
    Value input = padTensorOp.getSource();
    if (!llvm::isa<RankedTensorType>(input.getType()))
      return failure();
    auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
    auto inputRank = inputDims.size();

    auto oldResultType =
        dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
    if (!oldResultType)
      return failure();

    auto outputDims = oldResultType.getShape();

    // Extract the static info from the high and low operands.
    SmallVector<int64_t> constOperandsLow;
    SmallVector<Value> newLows;
    for (auto operand : padTensorOp.getLow()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsLow.push_back(ShapedType::kDynamic);
        newLows.push_back(operand);
        continue;
      }
      constOperandsLow.push_back(intOp.getExtValue());
    }
    SmallVector<int64_t> constOperandsHigh;
    SmallVector<Value> newHighs;
    for (auto operand : padTensorOp.getHigh()) {
      APSInt intOp;
      if (!matchPattern(operand, m_ConstantInt(&intOp))) {
        constOperandsHigh.push_back(ShapedType::kDynamic);
        newHighs.push_back(operand);
        continue;
      }
      constOperandsHigh.push_back(intOp.getExtValue());
    }

    SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
    SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());

    // Verify the op is well-formed.
    if (inputDims.size() != outputDims.size() ||
        inputDims.size() != constLow.size() ||
        inputDims.size() != constHigh.size())
      return failure();

    auto lowCount = 0;
    auto highCount = 0;
    for (size_t i = 0; i < inputRank; i++) {
      if (constLow[i] == ShapedType::kDynamic)
        constLow[i] = constOperandsLow[lowCount++];
      if (constHigh[i] == ShapedType::kDynamic)
        constHigh[i] = constOperandsHigh[highCount++];
    }

    auto staticLow = ArrayRef<int64_t>(constLow);
    auto staticHigh = ArrayRef<int64_t>(constHigh);

    // Calculate the output sizes with the static information.
    SmallVector<int64_t> newOutDims;
    for (size_t i = 0; i < inputRank; i++) {
      if (outputDims[i] == ShapedType::kDynamic) {
        newOutDims.push_back(
            (staticLow[i] == ShapedType::kDynamic ||
                     staticHigh[i] == ShapedType::kDynamic ||
                     inputDims[i] == ShapedType::kDynamic
                 ? ShapedType::kDynamic
                 : inputDims[i] + staticLow[i] + staticHigh[i]));
      } else {
        newOutDims.push_back(outputDims[i]);
      }
    }

    // Nothing to fold if no dimension was made more static.
    if (SmallVector<int64_t>(outputDims) == newOutDims ||
        llvm::all_of(newOutDims,
                     [&](int64_t x) { return x == ShapedType::kDynamic; }))
      return failure();

    // Rewrite the op using the new static type.
    auto newResultType = RankedTensorType::get(
        newOutDims, padTensorOp.getType().getElementType());
    auto newOp = PadOp::create(
        rewriter, padTensorOp->getLoc(), newResultType, input, staticLow,
        staticHigh, newLows, newHighs, padTensorOp.getNofold(),
        getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));

    IRMapping mapper;
    padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
    rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
                                                newOp);
    return success();
  }
};
/// Folds a chain of tensor.pad ops with the same constant padding value by
/// summing their low/high padding widths.
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    if (padOp.getNofold()) {
      return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
    }

    auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!producerPad || producerPad.getNofold()) {
      return rewriter.notifyMatchFailure(
          padOp, "producer is not a foldable tensor.pad op");
    }

    // Fail if the tensor::PadOps padding values do not match.
    Value consumerPadValue = padOp.getConstantPaddingValue();
    Value producerPadValue = producerPad.getConstantPaddingValue();
    if (!consumerPadValue || !producerPadValue ||
        consumerPadValue != producerPadValue) {
      return rewriter.notifyMatchFailure(
          padOp,
          "cannot fold PadOps with different or non-constant padding values");
    }

    Location loc = padOp.getLoc();
    AffineExpr d0, d1;
    bindDims(rewriter.getContext(), d0, d1);

    // Combine the low/high paddings of the two tensor::PadOps.
    auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
                           ArrayRef<OpFoldResult> producerPaddings) {
      SmallVector<OpFoldResult> sumPaddings;
      for (auto [consumerIndex, producerIndex] :
           llvm::zip_equal(consumerPaddings, producerPaddings)) {
        sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
            rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
      }
      return sumPaddings;
    };

    SmallVector<OpFoldResult> newHighPad =
        addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
    SmallVector<OpFoldResult> newLowPad =
        addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());

    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), padOp.getResultType(),
        producerPad.getSource(), newLowPad, newHighPad, padOp.getNofold(),
        getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
    rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                                newPadOp.getRegion().begin());
    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};
} // namespace
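// Illustrative example (not from the original source): two constant pads
// with the same padding value compose additively:
//
//   %0 = tensor.pad %t low[1] high[2] { ... yield %cst ... }
//   %1 = tensor.pad %0 low[0] high[3] { ... yield %cst ... }
// -->
//   %1 = tensor.pad %t low[1] high[5] { ... yield %cst ... }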
LogicalResult
PadOp::reifyResultShapes(OpBuilder &b,
                         ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  SmallVector<OpFoldResult> lp = getMixedLowPad();
  SmallVector<OpFoldResult> hp = getMixedHighPad();
  for (int64_t i = 0; i < getResultType().getRank(); ++i) {
    if (!getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = b.getIndexAttr(getType().getDimSize(i));
      continue;
    }
    Location loc = getLoc();
    Value dim = b.createOrFold<tensor::DimOp>(
        loc, getSource(), arith::ConstantIndexOp::create(b, loc, i));

    AffineExpr d0, d1, d2;
    bindDims(b.getContext(), d0, d1, d2);
    reifiedReturnShapes[0][i] = affine::makeComposedFoldedAffineApply(
        b, loc, {d0 + d1 + d2}, {dim, lp[i], hp[i]});
  }
  return success();
}
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
              FoldOrthogonalPaddings, FoldStaticPadding,
              FoldConsecutiveConstantPadding>(context);
}
/// Return the padding value of the PadOp if it is a constant or is defined
/// outside of the PadOp block; return an empty value otherwise.
Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
  if (!yieldOp)
    return {};
  Value padValue = yieldOp.getValue();
  // Check if the yield value is a constant.
  if (matchPattern(padValue, m_Constant()))
    return padValue;
  // Check if the yield value is defined inside the PadOp block.
  if (padValue.getParentBlock() == &getRegion().front())
    return {};
  // Else: the yield value is defined outside of the PadOp block.
  return padValue;
}
OpFoldResult PadOp::fold(FoldAdaptor) {
  if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
      !getNofold())
    return getSource();
  return {};
}
//===----------------------------------------------------------------------===//
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//

OpResult ParallelInsertSliceOp::getTiedOpResult() {
  InParallelOpInterface parallelCombiningParent = getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes,
                                  ArrayRef<OpFoldResult> strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  result.addAttributes(attrs);
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
// Build a ParallelInsertSliceOp with dynamic entries given as Ranges.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest,
                                  ArrayRef<Range> ranges,
                                  ArrayRef<NamedAttribute> attrs) {
  auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
  build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
                                  Value source, Value dest, ValueRange offsets,
                                  ValueRange sizes, ValueRange strides,
                                  ArrayRef<NamedAttribute> attrs) {
  SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
      offsets, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> sizeValues =
      llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
  SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
      strides, [](Value v) -> OpFoldResult { return v; });
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
// Build an InsertSliceOp with zero offsets, unit strides, and the given
// sizes.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
                          Value dest, ArrayRef<OpFoldResult> sizes,
                          ArrayRef<NamedAttribute> attrs) {
  Attribute zeroIdxAttr = b.getIndexAttr(0);
  Attribute oneIdxAttr = b.getIndexAttr(1);
  SmallVector<OpFoldResult> writeStrides(sizes.size(), oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(sizes.size(), zeroIdxAttr);
  build(b, result, source, dest, writeOffsets, sizes, writeStrides, attrs);
}
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<InParallelOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected InParallelOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  // Verify result type against inferred type.
  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  if (result != SliceVerificationResult::Success)
    return produceSliceErrorMsg(result, *this, expectedType);

  // Verify that offsets, sizes, and strides are in bounds with respect to the
  // destination tensor.
  SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
      getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
      getStaticStrides(), /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
  return getDestMutable();
}
Operation *ParallelInsertSliceOp::getIteratingParent() {
  // The iterating parent is the parent of the InParallelOpInterface op that
  // contains this op.
  if (auto combiningOp =
          dyn_cast<InParallelOpInterface>(getOperation()->getParentOp()))
    return combiningOp->getParentOp();
  return nullptr;
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//

void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  // Use GatherOp::inferResultType on the `dest` type and verify that the
  // expected type matches the source type.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    Type aggregateType, ValueRange dynamicSizes) {
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
  auto aggregateType = RankedTensorType::get(staticShape, element.getType());
  build(builder, result, aggregateType, element, dynamicSizes);
}

void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
                    ArrayRef<OpFoldResult> sizes) {
  SmallVector<int64_t> staticShape;
  SmallVector<Value> dynamicSizes;
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
  build(builder, result, element, staticShape, dynamicSizes);
}

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
LogicalResult SplatOp::verify() {
  return verifyDynamicDimensionCount(getOperation(), getType(),
                                     getDynamicSizes());
}
LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1,
                             SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] =
          builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
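// Illustrative example (not from the original source): splatting a constant
// into a statically shaped tensor folds to a splat elements attribute:
//
//   %cst = arith.constant 1.0 : f32
//   %0 = tensor.splat %cst : tensor<2x2xf32>
// -->
//   dense<1.0> : tensor<2x2xf32>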
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//

static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
  // 2. Exclude DPS ops that are also LoopLikeOpInterface ops, because they
  //    need special handling for "dropped" uses.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  // If no operand comes from a foldable tensor::CastOp then fail.
  return hasFoldableTensorCastOperand(op);
}

namespace {
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the tensor.cast has a source that is more static than the consuming op.
struct FoldTensorCastProducerOp
    : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
  using OpInterfaceRewritePattern<
      DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                PatternRewriter &rewriter) const override {
    // Reject relayout ops; they have their own patterns.
    if (!foldTensorCastPrecondition(op) ||
        isa<linalg::RelayoutOpInterface>(*op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

    // Clone the op with the new result types and operands.
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(tensor::CastOp::create(
            rewriter, op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};
} // namespace
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(results.getContext());
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
std::string join(const Ts &...args)
Helper function to concatenate arguments into a std::string.
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static Value foldExtractAfterInsert(ExtractOp extractOp)
If we have an ExtractOp consuming an InsertOp with the same indices, we can return the InsertOp's sca...
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, ArrayRef< int64_t > indices, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static bool foldTensorCastPrecondition(DestinationStyleOpInterface op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
Attributes are known-constant values of operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
AffineExpr getAffineSymbolExpr(unsigned position)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
AffineExpr getAffineDimExpr(unsigned position)
AffineMap getConstantAffineMap(int64_t val)
Returns a single constant result affine map with 0 dimensions and 0 symbols.
MLIRContext * getContext() const
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Operation is the basic unit of execution within MLIR.
MutableArrayRef< OpOperand > getOpOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns)
Patterns to fold extracts of a collapse_shaped tensor to an extract of the source tensor.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
OpFoldResult getAsIndexOpFoldResult(MLIRContext *ctx, int64_t val)
Convert int64_t to integer attributes of index type and return them as OpFoldResult.
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< int64_t > delinearize(int64_t linearIndex, ArrayRef< int64_t > strides)
Given the strides together with a linear index in the dimension space, return the vector-space offset...
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SmallVector< int64_t, 2 > ReassociationIndices
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.