#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/MathExtras.h"

using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
  return op;
if (complex::ConstantOp::isBuildableWith(value, type))
  return builder.create<complex::ConstantOp>(loc, type,
                                             llvm::cast<ArrayAttr>(value));
auto tensorType = llvm::cast<RankedTensorType>(value.getType());
if (tensorType.isDynamicDim(dim))
  return builder.createOrFold<tensor::DimOp>(loc, value, dim);

auto tensorType = llvm::cast<RankedTensorType>(value.getType());
for (int64_t i = 0; i < tensorType.getRank(); ++i)

auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
assert(tensorType && "expected tensor type");

auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
return destOp.getTiedOpOperand(opResult)->get();

if (!tensorType.hasStaticShape()) {
for (int64_t sz : tensorType.getShape())
b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());

if (llvm::isa<TensorType>(opResult.getType())) {
if (failed(destination))
result.push_back(*destination);
if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
    return rtp1.getShape() == rtp2.getShape() &&
           rtp1.getElementType() == rtp2.getElementType();
llvm::SmallBitVector droppedDims(mixedSizes.size());
int64_t shapePos = reducedShape.size() - 1;

for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
  size_t idx = mixedSizes.size() - size.index() - 1;
  bool isStaticUnitSize =
      llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;
  assert(isStaticUnitSize && "expected unit dim");
  droppedDims.set(idx);
  if (!isStaticUnitSize) {
  if (reducedShape[shapePos] == 1) {
    droppedDims.set(idx);
assert(shapePos < 0 && "dimension mismatch");
static RankedTensorType
assert(type.getNumDynamicDims() == dynamicSizes.size() &&
       "incorrect number of dynamic sizes");
for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
  if (type.isDynamicDim(i)) {
    Value dynamicSize = dynamicSizes[ctr++];
    if (cst.has_value()) {
      if (cst.value() < 0) {
        foldedDynamicSizes.push_back(dynamicSize);
      staticShape[i] = *cst;
    foldedDynamicSizes.push_back(dynamicSize);
if (inputs.size() != 1 || outputs.size() != 1)
Type a = inputs.front(), b = outputs.front();
auto aT = dyn_cast<TensorType>(a);
auto bT = dyn_cast<TensorType>(b);
if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())

LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
auto tensorBitcastOperand =
    tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
if (!tensorBitcastOperand)
auto resultType = cast<TensorType>(tensorBitcast.getType());
rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                       tensorBitcastOperand.getOperand());

results.add<ChainedTensorBitcast>(context);
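// Illustrative example (not part of the original excerpt): ChainedTensorBitcast
// collapses back-to-back bitcasts into one, e.g.
//   %1 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xui32>
//   %2 = tensor.bitcast %1 : tensor<4xui32> to tensor<4xsi32>
// canonicalizes to a single
//   %2 = tensor.bitcast %0 : tensor<4xi32> to tensor<4xsi32>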
setNameFn(getResult(), "cast");

auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
auto targetType = llvm::dyn_cast<RankedTensorType>(target);
if (!sourceType || !targetType)
if (sourceType.getElementType() != targetType.getElementType())
if (sourceType.getRank() != targetType.getRank())
if (sourceType.getEncoding() != targetType.getEncoding())
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
  if (!ShapedType::isDynamic(std::get<0>(t)) &&
      ShapedType::isDynamic(std::get<1>(t)))

castOp.getSource().getType());

auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
operand.set(castOp.getOperand());
return success(folded);
if (inputs.size() != 1 || outputs.size() != 1)
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<TensorType>(a);
auto bT = llvm::dyn_cast<TensorType>(b);
if (aT.getElementType() != bT.getElementType())

int64_t rank = one.getRank();
if (rank != two.getRank())
for (int64_t i = 0; i < rank; ++i) {
  if (one.isDynamicDim(i)) {
    join.push_back(two.getDimSize(i));
  if (two.isDynamicDim(i)) {
    join.push_back(one.getDimSize(i));
  if (one.getDimSize(i) != two.getDimSize(i))
  join.push_back(one.getDimSize(i));
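// Illustrative example (not part of the original excerpt): the join keeps the
// more static extent per dimension, so joining tensor<?x8xf32> with
// tensor<4x?xf32> yields tensor<4x8xf32>; mismatched static extents (e.g. 4
// vs. 5) have no join.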
LogicalResult matchAndRewrite(CastOp tensorCast,
auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
if (!tensorCastOperand)
llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
auto resultType = llvm::cast<TensorType>(tensorCast.getType());
auto newJoin = joinShapes(sourceType, resultType);
if (firstJoin != newJoin)
rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                    tensorCastOperand.getOperand());

LogicalResult matchAndRewrite(CastOp tensorCast,
auto extractOperand =
    tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
auto rankedResultType =
    llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
if (!rankedResultType)
rankedResultType.getShape() ==
    llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
extractOperand.getStaticSizes(), extractOperand.getType().getShape());
for (size_t i = 0, e = sizes.size(); i < e; i++) {
  if (dimMask && dimMask->count(i))
  int64_t dim = rankedResultType.getShape()[dimIndex++];
  if (ShapedType::isDynamic(dim))
  sizes[i] = rewriter.getIndexAttr(dim);
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
    tensorCast, rankedResultType, extractOperand.getSource(),
    extractOperand.getMixedOffsets(), sizes,
    extractOperand.getMixedStrides());

results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
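// Illustrative example (not part of the original excerpt): ChainedTensorCast
// merges a cast-of-cast when no static information is lost, e.g.
//   %1 = tensor.cast %0 : tensor<4x4xf32> to tensor<?x?xf32>
//   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x?xf32>
// folds to
//   %2 = tensor.cast %0 : tensor<4x4xf32> to tensor<4x?xf32>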
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
  assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
  llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
    return llvm::cast<RankedTensorType>(type);
  int64_t concatRank = tensorTypes[0].getRank();
  assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
  for (int64_t i = 0, e = concatRank; i < e; ++i) {
    for (auto tensorType : tensorTypes)
  for (auto tensorType : tensorTypes)
  sizes[dim] = concatSize.asInteger();

FailureOr<RankedTensorType> resultType =
    inferResultType(dim, inputs.getTypes());
assert(succeeded(resultType) && "failed to infer concatenation result type");
build(builder, result, *resultType, dim, inputs);
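// Illustrative inference (not part of the original excerpt): concatenating
// tensor<3x?xf32> and tensor<3x5xf32> along dim 1 infers tensor<3x?xf32>
// (dynamic + static is dynamic); for non-concatenated dimensions the static
// extents must agree, with a dynamic extent joining against a static one.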
if (getInputs().size() < 1)
  return emitOpError("requires at least one input");

for (auto input : getInputs())
  inputTypes.push_back(cast<RankedTensorType>(input.getType()));

RankedTensorType resultType = getResultType();
int64_t resultRank = getRank();
if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
      return type.getRank() != resultRank;
  return emitOpError("rank of concatenated inputs must match result rank");

Type resultElementType = resultType.getElementType();
if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
      return type.getElementType() != resultElementType;
  return emitOpError("inputs and result element type must match");

int64_t dim = getDim();
if (dim >= resultRank)
  return emitOpError("concatenation dim must be less than the tensor rank");

for (int64_t i = 0, e = resultRank; i < e; ++i) {
  for (auto tensorType : inputTypes) {
    FailureOr<SaturatedInteger> maybeSize =
    if (failed(maybeSize))
      return emitOpError("static concatenation size mismatch along ")
             << "non-concatenated dimension " << i;

for (auto tensorType : inputTypes)
sizes[dim] = concatSize.asInteger();
auto inferredResultType =
for (auto [inferredSize, actualSize] :
     llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
  bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
                    ShapedType::isDynamic(actualSize);
  if (!hasDynamic && inferredSize != actualSize)
    return emitOpError("result type ")
           << resultType << " does not match inferred shape "
           << inferredResultType << " static sizes";
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
  size_t numInputs = getInputs().size();
  uint64_t concatDim = getDim();

  inputShapes.reserve(numInputs);
  concatOffsets.reserve(numInputs);

  outputShape = inputShape;
  concatOffsets.push_back(zero);
  concatOffsets.push_back(outputShape[concatDim]);
  builder, loc, addExpr,
  {outputShape[concatDim], inputShape[concatDim]});
  inputShapes.emplace_back(std::move(inputShape));

  Value replacement = builder.create<tensor::EmptyOp>(
      loc, outputShape, getType().getElementType());

  int64_t rank = getType().getRank();
  offsets[concatDim] = concatOffsets[index];
  auto insertSlice = builder.create<tensor::InsertSliceOp>(
      loc, input, replacement, offsets, inputShapes[index], strides);

  if (replacement.getType() != getType()) {
    replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
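// Sketch of the decomposition above (illustrative, not part of the original
// excerpt): a concat is rewritten as a tensor.empty of the concatenated shape
// followed by one tensor.insert_slice per input, each offset along the concat
// dim by the running sum of the preceding input extents, with a final
// tensor.cast if the materialized type differs from the op's result type.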
int64_t dim = getDim();
RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());

Value init = inputs[0];
int64_t rank = getType().getRank();

for (int64_t i = 0; i < rank; ++i) {
  if (!getType().isDynamicDim(i)) {
  } else if (!inferredResultType.isDynamicDim(i)) {
    builder.getIndexAttr(inferredResultType.getDimSize(i)));
    reifiedReturnShapes[0][i] =
        builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();

if (getType().isDynamicDim(dim)) {
  builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
reifiedReturnShapes[0][dim] =

void ConcatOp::getAsmResultNames(
  setNameFn(getResult(), "concat");

if (inputs.size() == 1 && inputs[0].getType() == getResultType())

LogicalResult matchAndRewrite(ConcatOp concatOp,
if (concatOp.getInputs().size() != 1)
concatOp.getInputs()[0]);

results.add<SingleInputConcatOp>(context);
setNameFn(getResult(), "dim");

Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
build(builder, result, source, indexValue);

std::optional<int64_t> DimOp::getConstantIndex() {

auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
if (!rankedSourceType)
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());

int64_t indexVal = index.getInt();
if (indexVal < 0 || indexVal >= tensorType.getRank())

if (!tensorType.isDynamicDim(index.getInt())) {
  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);

Operation *definingOp = getSource().getDefiningOp();

if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
  llvm::cast<RankedTensorType>(fromElements.getResult().getType());
  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
  auto dynExtents = fromElements.getDynamicExtents().begin();
  for (auto dim : resultType.getShape().take_front(index.getInt()))
    if (ShapedType::isDynamic(dim))
  return Value{*dynExtents};

unsigned unsignedIndex = index.getValue().getZExtValue();

if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
      sliceOp.isDynamicSize(unsignedIndex)) {
    return {sliceOp.getDynamicSize(unsignedIndex)};
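// Illustrative fold (not part of the original excerpt): a dim over a static
// extent becomes a constant, e.g.
//   %c1 = arith.constant 1 : index
//   %d = tensor.dim %t, %c1 : tensor<4x8xf32>
// folds to an index attribute of value 8.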
LogicalResult matchAndRewrite(DimOp dimOp,
auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
Value newSource = castOp.getOperand();

LogicalResult matchAndRewrite(DimOp dimOp,
auto source = dimOp.getSource();
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
auto resultIndex = cast<OpResult>(source).getResultNumber();
auto *initOperand = destOp.getDpsInitOperand(resultIndex);
dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });

LogicalResult matchAndRewrite(DimOp dim,
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
if (extract.getType() != dim.getType())
  rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);

results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
assert(all_of(staticShape,
              [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
       "expected only static sizes");
build(builder, result, staticShape, elementType, ValueRange{}, encoding);

build(builder, result, tensorType, dynamicSizes);

build(builder, result, staticShape, elementType, dynamicSizes, encoding);

return emitOpError("incorrect number of dynamic sizes, has ")
       << getType().getNumDynamicDims();

for (int64_t i = 0; i < getType().getRank(); ++i) {
  if (getType().isDynamicDim(i)) {

Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)

for (int64_t i = 0; i < getType().getRank(); ++i) {
  if (getType().isDynamicDim(i)) {

LogicalResult matchAndRewrite(EmptyOp op,
op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
if (foldedTensorType == op.getType())
auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
                                      foldedDynamicSizes);
LogicalResult matchAndRewrite(tensor::DimOp dimOp,
std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
if (!emptyTensorOp || !maybeConstantIndex)
auto emptyTensorType = emptyTensorOp.getType();
if (*maybeConstantIndex < 0 ||
    *maybeConstantIndex >= emptyTensorType.getRank() ||
    !emptyTensorType.isDynamicDim(*maybeConstantIndex))
emptyTensorOp.getDynamicSize(*maybeConstantIndex));

LogicalResult matchAndRewrite(CastOp castOp,
auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
newMixedSizes.reserve(currMixedSizes.size());
assert(resultShape.size() == currMixedSizes.size() &&
       "mismatch in result shape and sizes of empty op");
for (auto it : llvm::zip(resultShape, currMixedSizes)) {
  int64_t newDim = std::get<0>(it);
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
    if (ShapedType::isDynamic(newDim) ||
        newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
      producer, "mismatch in static value of shape of empty tensor "
                "result and cast result");
    newMixedSizes.push_back(attr);
  if (!ShapedType::isDynamic(newDim)) {
    newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
  newMixedSizes.push_back(currDim);
resultType.getElementType());

results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
            ReplaceEmptyTensorStaticShapeDims>(context);
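// Illustrative canonicalization (not part of the original excerpt): when a
// dynamic extent of tensor.empty is a known constant, it is folded into the
// type, e.g.
//   %c4 = arith.constant 4 : index
//   %0 = tensor.empty(%c4) : tensor<?x8xf32>
// becomes
//   %0 = tensor.empty() : tensor<4x8xf32>
// with a tensor.cast back to the original type for existing uses.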
std::optional<Attribute> cst = std::nullopt) {
if (source && source.isSplat() && result.hasStaticShape() &&

struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
extract, tensorCast.getSource(), extract.getIndices());

void ExtractOp::getAsmResultNames(
  setNameFn(getResult(), "extracted");

auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
  return emitOpError("incorrect number of indices for extract_element");
if (Attribute tensor = adaptor.getTensor()) {
  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
    return splatTensor.getSplatValue<Attribute>();
  if (isa<DenseResourceElementsAttr>(tensor))

for (Attribute indice : adaptor.getIndices()) {
  if (!indice || !llvm::isa<IntegerAttr>(indice))
  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());

if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
  auto rank = tensorType.getRank();
  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= tensorType.getDimSize(i);
  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
  return fromElementsOp.getElements()[flatIndex];

if (Attribute tensor = adaptor.getTensor()) {
  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValues<Attribute>()[indices];

results.add<ExtractFromTensorCast>(context);
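// Worked example of the row-major flattening above (illustrative, not part of
// the original excerpt): for a tensor<2x3xf32> from_elements and indices
// [1, 2], the loop accumulates flatIndex = 1 * 3 + 2 * 1 = 5, i.e. the last
// element of the second row.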
void FromElementsOp::getAsmResultNames(
  setNameFn(getResult(), "from_elements");

assert(!elements.empty() && "expected at least one element");
{static_cast<int64_t>(elements.size())}, elements.front().getType());
build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))

struct ExtractElementFromIndexCast
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
auto newExtract = rewriter.create<tensor::ExtractOp>(
    loc, elementTy, indexCast.getIn(), extract.getIndices());

results.add<ExtractElementFromIndexCast>(context);
void GatherOp::getAsmResultNames(
  setNameFn(getResult(), "gather");

RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
resultShape.reserve(resultShape.size() + sourceType.getRank());
for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
  if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
    resultShape.push_back(1);
  resultShape.push_back(sourceType.getDimSize(idx));
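// Illustrative shape inference (not part of the original excerpt): for a
// tensor<4x5x6xf32> source, indices of type tensor<7x1xindex>, and
// gather_dims([1]), the result keeps the leading indices dims and turns the
// gathered source dim into 1: tensor<7x4x1x6xf32>, or tensor<7x4x6xf32> in
// the rank-reduced variant.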
static LogicalResult
StringRef gatherOrScatter, StringRef sourceOrDest) {
return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

int64_t numGatherDims = dims.size();
if (numGatherDims > rank)
  << "_dims overflow " << sourceOrDest << " rank";
if (indices.empty() || indices.back() != numGatherDims)
  << "_dims length must match the size of last dimension of indices";

for (int64_t val : dims) {
  << "_dims value must be non-negative";
  << "_dims value must be smaller than " << sourceOrDest << " rank";

for (int64_t i = 1; i < numGatherDims; ++i) {
  if (dims[i - 1] >= dims[i])
    << "_dims values must be strictly increasing";
int64_t sourceRank = getSourceType().getRank();
getIndicesType().getShape(), sourceRank,
"gather", "source")))

RankedTensorType expectedResultType = GatherOp::inferResultType(
    getSourceType(), getIndicesType(), gatherDims, false);
RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
    getSourceType(), getIndicesType(), gatherDims, true);
if (getResultType() != expectedResultType &&
    getResultType() != expectedRankReducedResultType) {
  return emitOpError("result type "
         << expectedResultType << " or its rank-reduced variant "
         << expectedRankReducedResultType << " (got: " << getResultType()

llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
return reshapedSource;

void InsertOp::getAsmResultNames(
  setNameFn(getResult(), "inserted");

auto destType = llvm::cast<RankedTensorType>(getDest().getType());
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
  return emitOpError("incorrect number of indices");

if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
  if (scalar == splatDest.getSplatValue<Attribute>())
void GenerateOp::getAsmResultNames(
  setNameFn(getResult(), "generated");

for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
  if (getType().isDynamicDim(dim)) {
    reifiedReturnShapes[0][dim] = getOperand(idx++);
  reifiedReturnShapes[0][dim] =

RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
if (getNumOperands() != resultType.getNumDynamicDims())
  return emitError("must have as many index operands as dynamic extents "
                   "in the result type");

LogicalResult GenerateOp::verifyRegions() {
  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
  if (!llvm::all_of(getBody().getArgumentTypes(),
    return emitError("all body arguments must be index");
  if (getBody().getNumArguments() != resultTy.getRank())
    return emitError("must have one body argument per input dimension");

  auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
  if (yieldOp.getValue().getType() != resultTy.getElementType())
    "body must be terminated with a `yield` operation of the tensor "

void GenerateOp::build(
build(b, result, resultTy, dynamicExtents);

auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);

LogicalResult matchAndRewrite(GenerateOp generateOp,
generateOp.getType(), generateOp.getDynamicExtents(),
foldedDynamicSizes);
if (foldedTensorType == generateOp.getType())
auto loc = generateOp.getLoc();
rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
newOp.getBody().begin());
generateOp.getType(), newOp);

struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
Block *body = &tensorFromElements.getBody().front();
rewriter.clone(op, mapping);

results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");

auto type = getOperand().getType();
auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
return IntegerAttr();
void ReshapeOp::getAsmResultNames(
  setNameFn(getResult(), "reshape");

int64_t numElements = 1;
for (auto dim : type.getShape())

return emitOpError("element types of source and destination tensor "
                   "types should be the same");

auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
if (resultRankedType) {
  if (operandRankedType && resultRankedType.hasStaticShape() &&
      operandRankedType.hasStaticShape()) {
    return emitOpError("source and destination tensor should have the "
                       "same number of elements");
  if (ShapedType::isDynamic(shapeSize))
    return emitOpError("cannot use shape operand with dynamic length to "
                       "reshape to statically-ranked tensor type");
  if (shapeSize != resultRankedType.getRank())
    "length of shape operand differs from the result's tensor rank");

llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
return reshapedSource;

if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
  getSourceMutable().assign(reshapeOpProducer.getSource());

auto source = getSource();
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
auto resultTy = dyn_cast<RankedTensorType>(getType());
if (!sourceTy || !resultTy || sourceTy != resultTy)

if (sourceTy.getRank() == 1)

if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
  auto elements = fromElements.getElements();
  sourceTy.getRank() == static_cast<int64_t>(elements.size());
  for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
    auto element = elements[id];
    dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
    if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
      dynamicNoop &= dimOp.getSource() == source;
      cst.has_value() && cst.value() == static_cast<int64_t>(id);
    dynamicNoop = false;
void CollapseShapeOp::getAsmResultNames(
  setNameFn(getResult(), "collapsed");

void ExpandShapeOp::getAsmResultNames(
  setNameFn(getResult(), "expanded");

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  if (llvm::is_contained(it.value(), resultDim))
  llvm_unreachable("could not find reassociation group");

FailureOr<SmallVector<OpFoldResult>>
RankedTensorType expandedType,
std::optional<SmallVector<OpFoldResult>> outputShape =
return *outputShape;

auto [staticOutputShape, dynamicOutputShape] =
build(builder, result, cast<RankedTensorType>(resultType), src,
      dynamicOutputShape, staticOutputShape);

auto tensorResultTy = cast<RankedTensorType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
    builder, result.location, tensorResultTy, reassociation, inputShape);
if (succeeded(outputShape)) {
  outputShapeOrEmpty = *outputShape;
build(builder, result, tensorResultTy, src, reassociation,
      outputShapeOrEmpty);

getReassociationIndices());

getReassociationIndices());

RankedTensorType CollapseShapeOp::inferCollapsedType(
return inferCollapsedType(
    type.getContext(), reassociation)));

CollapseShapeOp::inferCollapsedType(RankedTensorType type,
auto shape = type.getShape();
newShape.reserve(reassociation.size());

unsigned currentDim = 0;
unsigned dim = m.getNumResults();
auto band = shape.slice(currentDim, dim);
if (llvm::is_contained(band, ShapedType::kDynamic))
  size = ShapedType::kDynamic;
for (unsigned d = 0; d < dim; ++d)
  size *= shape[currentDim + d];
newShape.push_back(size);
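// Illustrative collapsed-type inference (not part of the original excerpt):
// collapsing tensor<2x3x?xf32> with reassociation [[0, 1], [2]] multiplies
// the static band 2x3 into 6 and keeps the dynamic dim: tensor<6x?xf32>.
// Any dynamic dim inside a band makes the whole collapsed dim dynamic.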
auto resultType = inferCollapsedType(
    llvm::cast<RankedTensorType>(src.getType()),
build(b, result, resultType, src, attrs);

template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
RankedTensorType expandedType,
RankedTensorType collapsedType) {
auto maps = op.getReassociationMaps();
RankedTensorType expectedType =
    CollapseShapeOp::inferCollapsedType(expandedType, maps);
return op.emitOpError("expected collapsed type to be ")
       << expectedType << ", but got " << collapsedType;

auto srcType = getSrcType();
auto resultType = getResultType();

if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
  return emitOpError("expected number of static shape dims to be equal to "
                     "the output rank (")
         << resultType.getRank() << ") but found "
         << getStaticOutputShape().size() << " inputs instead";

if ((int64_t)getOutputShape().size() !=
    llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
  return emitOpError("mismatch in dynamic dims in output_shape and "
                     "static_output_shape: static_output_shape has ")
         << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
         << " dynamic dims while output_shape has " << getOutputShape().size()
template <typename TensorReshapeOp>
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
reshapeOp.getResultType(), attr.getRawData());

template <typename TensorReshapeOp>
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
reshapeOp, reshapeOp.getResultType(), splatOp.getInput());

template <typename TensorReshapeOp>
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
if (!shapedTy.hasStaticShape())
fromElements.getElements());

LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
RankedTensorType srcType =
    llvm::cast<RankedTensorType>(castOp.getSource().getType());
RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
    srcType, collapseShapeOp.getReassociationMaps());
if (newResultType == collapseShapeOp.getResultType()) {
  collapseShapeOp.getSrcMutable().assign(castOp.getSource());
auto newOp = rewriter.create<CollapseShapeOp>(
    collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
    collapseShapeOp.getReassociation());
collapseShapeOp, collapseShapeOp.getResultType(), newOp);
LogicalResult matchAndRewrite(DimOp dimOp,
auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();

std::optional<int64_t> dim = dimOp.getConstantIndex();
if (!dim.has_value())

RankedTensorType resultType = expandShapeOp.getResultType();
if (!resultType.isDynamicDim(*dim))

int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);

for (int64_t d : grp) {
  assert(!resultType.isDynamicDim(d) && "expected static dim");
  product *= resultType.getDimSize(d);

rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
dimOp, expr.floorDiv(product), srcDimSz);
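// Worked example (illustrative, not part of the original excerpt): expanding
// tensor<?xf32> into tensor<?x4x5xf32> with reassociation [[0, 1, 2]], the
// dynamic output dim 0 is the source dim floor-divided by the product of its
// static siblings: dim(src, 0) floordiv (4 * 5).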
LogicalResult matchAndRewrite(DimOp dimOp,
auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
if (!collapseShapeOp)

std::optional<int64_t> dim = dimOp.getConstantIndex();
if (!dim.has_value())

RankedTensorType resultType = collapseShapeOp.getResultType();
if (!resultType.isDynamicDim(*dim))

collapseShapeOp.getReassociationIndices()[*dim];
srcDimSizes.push_back(rewriter.create<DimOp>(
    dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
expandOp.getReassociationIndices();
auto outputIt = expandOp.getOutputShape().begin();

for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
  for (uint64_t outDim : innerReassoc) {
    if (!ShapedType::isDynamic(newOutputShape[outDim]))
    Value val = *outputIt;
    if (ShapedType::isDynamic(castSrcShape[inputDim])) {
      dynamicOutputShape.push_back(val);
    newOutputShape[outDim] = cst.getSExtValue();
    dynamicOutputShape.push_back(val);

if (expandOp.getOutputShape().size() == dynamicOutputShape.size())

for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
  for (auto outDim : reassoc[inDim]) {
    auto ofr = newOutputShape[outDim];
    if (ShapedType::isDynamic(ofr)) {
      newInputShape[inDim] = ShapedType::kDynamic;
    newInputShape[inDim] *= ofr;

newInputShape, expandOp.getSrcType().getElementType());
newOutputShape, expandOp.getSrcType().getElementType());
auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
auto newExpand = rewriter.create<ExpandShapeOp>(
    expandOp.getLoc(), outputType, inputCast.getResult(),
    expandOp.getReassociationIndices(), outputOfr);
newExpand.getResult());

ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
FoldDimOfCollapseShape>(context);

tensor::DimOp, RankedTensorType>,
FoldReshapeWithConstant<CollapseShapeOp>,
FoldReshapeWithSplat<CollapseShapeOp>,
FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
void ExtractSliceOp::getAsmResultNames(
  setNameFn(getResult(), "extracted_slice");

RankedTensorType ExtractSliceOp::inferResultType(
assert(static_cast<int64_t>(staticSizes.size()) ==
           sourceTensorType.getRank() &&
       "unexpected staticSizes not equal to rank of source");
sourceTensorType.getEncoding());

RankedTensorType ExtractSliceOp::inferResultType(
return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                       staticSizes, staticStrides);

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
auto inferredType = llvm::cast<RankedTensorType>(
    inferResultType(sourceRankedTensorType, offsets, sizes, strides));
int rankDiff = inferredType.getRank() - desiredResultRank;
auto shape = inferredType.getShape();
llvm::SmallBitVector dimsToProject =
for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
  if (!dimsToProject.test(pos))
    projectedShape.push_back(shape[pos]);
return inferredType;

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
return ExtractSliceOp::inferCanonicalRankReducedResultType(
    desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
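// Illustrative rank reduction (not part of the original excerpt): slicing a
// tensor<8x16x4xf32> with sizes [1, 16, 4] infers tensor<1x16x4xf32>; asking
// for a desired result rank of 2 projects away the leading unit dim, giving
// tensor<16x4xf32>.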
RankedTensorType resultType, Value source,
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
    sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

RankedTensorType resultType, Value source,
build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

RankedTensorType expectedType) {
return op->emitError("expected rank to be smaller or equal to ")
       << "the other rank. ";
return op->emitError("expected type to be ")
       << expectedType << " or a rank-reduced version. (size mismatch) ";
return op->emitError("expected element type to be ")
       << expectedType.getElementType();
llvm_unreachable("unexpected extract_slice op verification result");

RankedTensorType expectedType = ExtractSliceOp::inferResultType(
    getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
assert(sourceTensorType && "not a ranked tensor type");
auto sourceShape = sourceTensorType.getShape();
if (sourceShape.equals(desiredShape))
auto maybeRankReductionMask =
if (!maybeRankReductionMask)

reifiedReturnShapes.resize(1);
reifiedReturnShapes[0].reserve(getType().getRank());
for (const auto &size : enumerate(mixedSizes)) {
  if (droppedDims.test(size.index()))
  reifiedReturnShapes[0].push_back(size.value());
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
      return matchPattern(operand, matchConstantIndex());

auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();

Value newResult = rewriter.create<ExtractSliceOp>(
    loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
    sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
    sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
if (newResult.getType() != sliceOp.getType())
  newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
template <typename IterTy, typename ElemTy>
assert(offsets.size() == sizes.size());
assert(offsets.size() == strides.size());
if (offsets.empty())

int64_t offset = offsets.front();
int64_t size = sizes.front();
int64_t stride = strides.front();
if (offsets.size() == 1) {
  for (int64_t i = 0; i < size; ++i, offset += stride)
    outValues->push_back(*(values + offset));

for (int64_t i = 0; i < size; ++i, offset += stride) {
  auto begin = values + offset * counts.front();
  sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                offsets.drop_front(), sizes.drop_front(),
                                strides.drop_front(), outValues);
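// How the recursion above walks a flattened constant (illustrative, not part
// of the original excerpt): for a 2-D source, the outer call advances by
// whole rows (offset * counts.front() elements, where counts.front() is the
// row length) and the innermost call copies individual elements at the
// requested stride.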
class ConstantOpExtractSliceFolder final

ConstantOpExtractSliceFolder(MLIRContext *context,
      controlFn(std::move(controlFn)) {}

LogicalResult matchAndRewrite(ExtractSliceOp op,

auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())

int64_t count = sourceType.getNumElements();

auto offsets = op.getStaticOffsets();
if (llvm::is_contained(offsets, ShapedType::kDynamic))
auto sizes = op.getStaticSizes();
if (llvm::is_contained(sizes, ShapedType::kDynamic))
auto strides = op.getStaticStrides();
if (llvm::is_contained(strides, ShapedType::kDynamic))

counts.reserve(shape.size());
for (int64_t v : shape) {
  counts.push_back(count);

if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
  outValues.reserve(sourceType.getNumElements());
  sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
      elems.begin(), counts, offsets, sizes, strides, &outValues);
} else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
  outValues.reserve(sourceType.getNumElements());
  sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
      elems.begin(), counts, offsets, sizes, strides, &outValues);

patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
    op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,

ExtractSliceOp newOp) {
Value replacement = newOp.getResult();
if (replacement.getType() != op.getType())
  replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),

ExtractSliceOpCastFolder>(context);

static LogicalResult
ShapedType shapedType) {
auto shape = shapedType.getShape();
for (auto it : llvm::zip(op.getMixedSizes(), shape))

auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
    insertOp.isSameAs(extractOp, isSame))
  return insertOp.getSource();
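// Illustrative round-trip fold (not part of the original excerpt):
//   %1 = tensor.insert_slice %src into %dst[0, 0] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>
//   %2 = tensor.extract_slice %1[0, 0] [2, 2] [1, 1]
//       : tensor<4x4xf32> to tensor<2x2xf32>
// extracting exactly the slice that was just inserted folds %2 to %src.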
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
  return reshapedSource;
  if (getSourceType() == getType() &&
    return this->getSource();

auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
unsigned rank = rankedTensorType.getRank();
return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                              offsets, sizes, strides);

void InsertSliceOp::getAsmResultNames(
  setNameFn(getResult(), "inserted_slice");

build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,

build(b, result, source, dest, offsets, sizes, strides, attrs);

build(b, result, source, dest, offsetValues, sizeValues, strideValues);

RankedTensorType srcType, RankedTensorType dstType,
RankedTensorType expected = ExtractSliceOp::inferResultType(
    dstType, staticOffsets, staticSizes, staticStrides);
*expectedType = expected;

RankedTensorType expectedType;
getStaticSizes(), getStaticStrides(), &expectedType);

auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
if (!prevInsertOp ||
    prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
    !prevInsertOp.isSameAs(insertOp, isSame))
insertOp.getDestMutable().assign(prevInsertOp.getDest());

auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
    !extractOp.isSameAs(insertOp, isSame))
return extractOp.getSource();

if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
    getSourceType() == getType() &&
  return this->getSource();
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final

LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,

auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
    insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
    mixedOffsets, mixedSizes, mixedStrides);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
  toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                             sourceType, toInsert);
insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
mixedSizes, mixedStrides);
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {

LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
      return matchPattern(operand, matchConstantIndex());

auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
  auto castOp = v.getDefiningOp<tensor::CastOp>();
  if (!castOp)
    return std::nullopt;
  return castOp.getSource();
std::optional<Value> sourceCastSource =
    getSourceOfCastOp(insertSliceOp.getSource());
std::optional<Value> destCastSource =
    getSourceOfCastOp(insertSliceOp.getDest());
if (!sourceCastSource && !destCastSource)

auto src =
    (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
if (!srcType || !dstType)

staticSizes, srcType.getShape(), true);
if (!rankReductionMask.has_value())

int64_t rankReducedIdx = 0;
for (auto [idx, size] : enumerate(staticSizes)) {
  if (!rankReductionMask.value().contains(idx) &&
      !srcType.isDynamicDim(rankReducedIdx)) {
    rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
    size = srcType.getDimSize(rankReducedIdx++);

staticSizes, insertSliceOp.getStaticStrides()) !=

insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
mixedSizes, insertSliceOp.getMixedStrides());

bool isParallelInsert =
    std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
  replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                insertSliceOp.getDestType(),
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final

LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
RankedTensorType srcType = insertSliceOp.getSourceType();
if (srcType.getRank() != insertSliceOp.getDestType().getRank())
for (int64_t i = 0; i < srcType.getRank(); ++i) {
  if (std::optional<int64_t> constInt =
    newSrcShape[i] = *constInt;

newSrcShape, srcType.getElementType(), srcType.getEncoding());
if (srcType == newSrcType ||
    !tensor::CastOp::areCastCompatible(srcType, newSrcType))

if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());

results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
            InsertSliceOpCastFolder<InsertSliceOp>,
            InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);

auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
unsigned rank = rankedTensorType.getRank();
return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
setNameFn(getResult(), "padded");

Type typeToInfer, Type typeToInferFrom) {}

std::optional<OpAsmParser::UnresolvedOperand> optOperand,
Type &typeToInfer, Type typeToInferFrom) {
typeToInfer = typeToInferFrom;

auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
if (!expectedType) {
  return emitError("failed to infer expectedType from sourceType ")
         << sourceType << ", specified resultType is " << resultType;

if (resultType.getRank() != expectedType.getRank()) {
  << resultType << " does not match the inferred type "

for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
  if (expectedType.isDynamicDim(i))
  << resultType << " does not match the inferred type "
LogicalResult PadOp::verifyRegions() {
  auto &region = getRegion();
  unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
  return emitError("expected the block to have ") << rank << " arguments";

  if (!en.value().isIndex())
    return emitOpError("expected block argument ")
           << (en.index() + 1) << " to be an index";

  if (yieldOp.getValue().getType() !=
    return emitOpError("expected yield type to match shape element type");

RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
unsigned rank = sourceType.getRank();
if (staticLow.size() != rank)
  return RankedTensorType();
if (staticHigh.size() != rank)
  return RankedTensorType();
if (!resultShape.empty() && resultShape.size() != rank)
  return RankedTensorType();

for (auto i : llvm::seq<unsigned>(0, rank)) {
  if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
      staticHigh[i] == ShapedType::kDynamic) {
    inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
  int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
  assert((resultShape.empty() || size == resultShape[i] ||
          resultShape[i] == ShapedType::kDynamic) &&
         "mismatch between inferred shape and result shape");
  inferredShape.push_back(size);
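// Illustrative inference (not part of the original excerpt): padding a
// tensor<4x5xf32> with static low [1, 2] and high [2, 3] yields extents
// 1 + 4 + 2 = 7 and 2 + 5 + 3 = 10, i.e. tensor<7x10xf32>; any dynamic
// source extent or padding amount makes that dim dynamic.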
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
resultType = inferResultType(sourceType, staticLow, staticHigh);
build(b, result, resultType, source, low, high,

auto sourceType = llvm::cast<RankedTensorType>(source.getType());
unsigned rank = sourceType.getRank();
build(b, result, resultType, source, staticVector, staticVector, low, high,

auto sourceType = llvm::cast<RankedTensorType>(source.getType());
resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
assert(llvm::isa<RankedTensorType>(resultType));
build(b, result, resultType, source, dynamicLow, dynamicHigh,

build(b, result, resultType, source, low, high, nofold, attrs);

int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);

llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  for (const auto &en : enumerate(paddingWidths))
    paddedDims.set(en.index());
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
LogicalResult matchAndRewrite(PadOp padTensorOp,
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
if (padTensorOp.getNofold())
padTensorOp, padTensorOp.getResult().getType(),
padTensorOp.getSource());

LogicalResult matchAndRewrite(PadOp padTensorOp,
auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
auto newResultType = PadOp::inferResultType(
    llvm::cast<RankedTensorType>(castOp.getSource().getType()),
    padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
    padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
  padTensorOp.getSourceMutable().assign(castOp.getSource());
auto newOp = rewriter.create<PadOp>(
    padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
    padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
    padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
padTensorOp, padTensorOp.getResultType(), newOp);

LogicalResult matchAndRewrite(PadOp padTensorOp,
if (!padTensorOp.getResult().hasOneUse())
dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
tensorCastOp.getDest().getType()))
auto replacementOp = rewriter.create<PadOp>(
    padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
    padTensorOp.getSource(), padTensorOp.getStaticLow(),
    padTensorOp.getStaticHigh(), padTensorOp.getLow(),
    padTensorOp.getHigh(), padTensorOp.getNofold(),
rewriter.replaceOp(padTensorOp, replacementOp.getResult());
rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
LogicalResult matchAndRewrite(PadOp padOp,
auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
if (!outerPadOp || outerPadOp.getNofold())
auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();

int64_t rank = padOp.getSourceType().getRank();
if (outerSliceOp.getSourceType().getRank() != rank) {
  "cannot fold rank-reducing chain");

if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
  padOp, "cannot fold non-unit stride ExtractSliceOps");

if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
  "cannot fold PadOps with low padding");

Value innerValue = padOp.getConstantPaddingValue();
Value outerValue = outerPadOp.getConstantPaddingValue();
if (!innerValue || !outerValue ||
    innerAttr != outerAttr) {
  padOp, "cannot fold PadOps with different padding values");

llvm::SmallBitVector innerDims = padOp.getPaddedDims();
llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
if (innerDims.anyCommon(outerDims)) {
  padOp, "cannot fold PadOps with common padding dimensions");

OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
if (!innerDims.test(en.index()) &&
  en.value() = outerOffset;
if (!outerDims.test(en.index()) &&
  en.value() = innerOffset;
  padOp, "cannot find zero-offset and zero-padding pair");

if (!outerDims.test(en.index()))
OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
assert(!ShapedType::isDynamic(sourceSize) &&
       "expected padded dimension to have a static size");
  padOp, "cannot fold since the inner ExtractSliceOp size does not "
         "match the size of the outer padding");
en.value() = outerSliceOp.getMixedSizes()[en.index()];

if (innerDims.test(en.index()))
  newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
if (outerDims.test(en.index()))
  newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];

auto newSliceOp = rewriter.create<ExtractSliceOp>(
    padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
    innerSliceOp.getMixedStrides());
auto newPadOp = rewriter.create<PadOp>(
    padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
    padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
newPadOp.getRegion().begin());
rewriter.replaceOp(padOp, newPadOp.getResult());
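// Illustrative rewrite (not part of the original excerpt): a chain
// pad(extract_slice(pad(extract_slice(x)))) where the two pads touch
// disjoint dims, have zero low padding, and pad with the same constant is
// collapsed into a single extract_slice of the outer source followed by a
// single pad carrying the merged high paddings.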
LogicalResult matchAndRewrite(PadOp padTensorOp,
Value input = padTensorOp.getSource();
if (!llvm::isa<RankedTensorType>(input.getType()))
auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
auto inputRank = inputDims.size();

auto oldResultType =
    dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());

auto outputDims = oldResultType.getShape();

for (auto operand : padTensorOp.getLow()) {
  constOperandsLow.push_back(ShapedType::kDynamic);
  newLows.push_back(operand);
  constOperandsLow.push_back(intOp.getExtValue());

for (auto operand : padTensorOp.getHigh()) {
  constOperandsHigh.push_back(ShapedType::kDynamic);
  newHighs.push_back(operand);
  constOperandsHigh.push_back(intOp.getExtValue());

if (inputDims.size() != outputDims.size() ||
    inputDims.size() != constLow.size() ||
    inputDims.size() != constHigh.size())

for (size_t i = 0; i < inputRank; i++) {
  if (constLow[i] == ShapedType::kDynamic)
    constLow[i] = constOperandsLow[lowCount++];
  if (constHigh[i] == ShapedType::kDynamic)
    constHigh[i] = constOperandsHigh[highCount++];

for (size_t i = 0; i < inputRank; i++) {
  if (outputDims[i] == ShapedType::kDynamic) {
    newOutDims.push_back(
        (staticLow[i] == ShapedType::kDynamic ||
         staticHigh[i] == ShapedType::kDynamic ||
         inputDims[i] == ShapedType::kDynamic
             ? ShapedType::kDynamic
             : inputDims[i] + staticLow[i] + staticHigh[i]));
  newOutDims.push_back(outputDims[i]);

llvm::all_of(newOutDims,
             [&](int64_t x) { return x == ShapedType::kDynamic; }))

newOutDims, padTensorOp.getType().getElementType());
auto newOp = rewriter.create<PadOp>(
    padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
    newLows, newHighs, padTensorOp.getNofold(),
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3560 struct FoldConsecutiveConstantPadding :
public OpRewritePattern<tensor::PadOp> {
3563 LogicalResult matchAndRewrite(tensor::PadOp padOp,
3565 if (padOp.getNofold()) {
3569 auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
3570 if (!producerPad || producerPad.getNofold()) {
3572 padOp,
"producer is not a foldable tensor.pad op");
3576 Value consumerPadValue = padOp.getConstantPaddingValue();
3577 Value producerPadValue = producerPad.getConstantPaddingValue();
3578 if (!consumerPadValue || !producerPadValue ||
3579 consumerPadValue != producerPadValue) {
3582 "cannot fold PadOps with different or non-constant padding values");
3593 for (
auto [consumerIndex, producerIndex] :
3594 llvm::zip_equal(consumerPaddings, producerPaddings)) {
3596 rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
3602 addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
3604 addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
3606 auto newPadOp = rewriter.
create<tensor::PadOp>(
3607 padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
3608 newLowPad, newHighPad, padOp.getNofold(),
3611 newPadOp.getRegion().begin());
3612 rewriter.
replaceOp(padOp, newPadOp.getResult());
3621 results.
add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3622 FoldOrthogonalPaddings, FoldStaticPadding,
3623 FoldConsecutiveConstantPadding>(context);
3635 Value PadOp::getConstantPaddingValue() {
3636 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3639 Value padValue = yieldOp.getValue();
3651 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  ParallelCombiningOpInterface parallelCombiningParent =
      getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// ParallelInsertSliceOp builder overloads (signatures elided in this
// listing); each delegates downward once offsets/sizes/strides have been
// dispatched into their static and dynamic parts:
// ...
  build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
// ...
  build(b, result, source, dest, offsets, sizes, strides, attrs);
// ...
  build(b, result, source, dest, offsetValues, sizeValues, strideValues);
LogicalResult ParallelInsertSliceOp::verify() {
  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
           << *(getOperation()->getParentOp());

  RankedTensorType expectedType;
  SliceVerificationResult result =
      verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                          getStaticSizes(), getStaticStrides(), &expectedType);
  return produceSliceErrorMsg(result, *this, expectedType);
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}
LogicalResult ScatterOp::verify() {
  int64_t destRank = getDestType().getRank();
  ArrayRef<int64_t> scatterDims = getScatterDims();
  if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
                                       getIndicesType().getShape(), destRank,
                                       "scatter", "dest")))
    return failure();

  if (!getUnique())
    return emitOpError("requires 'unique' attribute to be set");

  // Use GatherOp::inferResultType on the `dest` type and verify that the
  // scatter source matches either the full or the rank-reduced variant.
  RankedTensorType expectedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
  RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
      getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
  if (getSourceType() != expectedSourceType &&
      getSourceType() != expectedRankReducedSourceType) {
    return emitOpError("source type mismatch: expected ")
           << expectedSourceType << " or its rank-reduced variant "
           << expectedRankReducedSourceType << " (got: " << getSourceType()
           << ")";
  }

  return success();
}
// SplatOp builder overloads (signatures elided in this listing) all funnel
// into the generated builder once the aggregate type and dynamic sizes are
// known:
// ...
  build(builder, result, aggregateType, element, dynamicSizes);
// ...
  build(builder, result, aggregateType, element, dynamicSizes);
// ...
  build(builder, result, element, staticShape, dynamicSizes);

void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}
LogicalResult SplatOp::verify() {
  if (getType().getNumDynamicDims() !=
      static_cast<int64_t>(getDynamicSizes().size()))
    return emitOpError("incorrect number of dynamic sizes, has ")
           << getDynamicSizes().size() << ", expected "
           << getType().getNumDynamicDims();
  return success();
}

LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
                           ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
  unsigned ctr = 0;
  for (int64_t i = 0; i < getType().getRank(); ++i) {
    if (getType().isDynamicDim(i)) {
      reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
    } else {
      reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
    }
  }
  return success();
}

OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
  auto constOperand = adaptor.getInput();
  if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
    return {};

  // Do not fold if the splat is not statically shaped.
  if (!getType().hasStaticShape())
    return {};

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(getType(), {constOperand});
}
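// Illustrative fold (hypothetical types): a tensor.splat of a constant
// 1.0 : f32 with result type tensor<2x3xf32> folds to a splat
// DenseElementsAttr holding 1.0 in every element; with result type
// tensor<?x3xf32> the fold bails out because the shape is not static.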
template <typename OpTy>
static LogicalResult reifyResultShapesImpl(
    OpTy op, OpBuilder &builder,
    ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
  return success();
}
template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind the dimension positions to the tile sizes.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}
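// Illustrative mapping (hypothetical values): inner_dims_pos = [1, 3] with
// mixed tiles [16, %t] yields {1 -> 16, 3 -> %t}.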
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (!ShapedType::isDynamic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}
template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  SmallVector<Value> dynamicTiles;
  SmallVector<int64_t> staticTiles;
  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
  return staticTiles;
}
/// Returns true if `dimsPos` is invalid. It is invalid when:
/// a) it contains duplicates,
/// b) at least one dimension is out of bound (a position must be >= 0 and
///    < rank), or
/// c) the number of elements in `dimsPos` is > `rank`.
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  llvm::DenseSet<int64_t> uniqued;
  for (int64_t dim : dimsPos)
    uniqued.insert(dim);
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}
/// Returns true if each dimension of `sourceShape` is no larger than the
/// corresponding dimension of `limitShape`; dynamic dimensions always pass.
static bool areAllInBound(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
}
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(
        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
  };

  // Verify tiles. Do not allow zero tiles.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be less than or equal to the input rank for pack (or
  // output rank for unpack), and their number must match inner_dims_pos.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // The packed rank must be the unpacked rank plus one dimension per tiling
  // factor.
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  // Verify that the packed shape is large enough to hold the expected packed
  // data, and that static inner tiles agree with the trailing packed dims.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shape of "
                         "tiled dimension in the packed type");
  }
  return success();
}
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
void PackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
SmallVector<int64_t> PackOp::getAllOuterDims() {
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);
}

SmallVector<int64_t> PackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getDestType().getShape();
  SmallVector<int64_t> res;
  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);
  return res;
}

bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(outputTileSizes,
                             invertPermutationVector(outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
    if (!constantTile) {
      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  }
  return false;
}
LogicalResult PackOp::verify() {
  if (failed(commonVerifierPackAndUnPackOp(*this)))
    return failure();

  // Verify the padding value, and bail out if a tile does not divide its
  // dimension fully; with dynamic tile factors or dimensions, a partial tile
  // without a padding value is undefined behavior.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}
/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
/// all Value's to kDynamic, even if they are arith.constant values.
static SmallVector<int64_t>
asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    // Check for Value first, as getConstantIntValue special-cases constants.
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    else
      result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
  }
  return result;
}
/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
/// the packed type, so that the two methods agree on which dimensions are
/// dynamic.
static SmallVector<int64_t> getPackOpResultTypeShape(
    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  }

  // Swap tile loops if outer_dims_perm is available.
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultShape, outerDimsPerm);

  // Append the inner tile dimensions.
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
  return resultShape;
}
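// Worked example (hypothetical values): sourceShape = [31, 64],
// innerDimsPos = [0], innerTileSizes = [8] gives outer dims
// [divideCeilSigned(31, 8), 64] = [4, 64]; appending the tile yields the
// packed shape [4, 64, 8].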
SmallVector<OpFoldResult> PackOp::getResultShape(
    OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
    ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
    ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);

  AffineExpr s0, s1;
  bindSymbols(builder.getContext(), s0, s1);
  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector(resultDims, outerDimsPerm);
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

  SmallVector<int64_t> resultTypeShape = getPackOpResultTypeShape(
      asShapeWithAnyValueAsDynamic(resultDims),
      asShapeWithAnyValueAsDynamic(innerTileSizes), innerDimsPos,
      outerDimsPerm);

  // Fix up `resultDims` so that a Value is used if and only if the result
  // type shape is dynamic in that dimension.
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (!ShapedType::isDynamic(resultTypeShape[i]))
      continue;
    resultDims[i] =
        getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
  }

  return resultDims;
}
/// Get the expected packed type based on source type, tile factors, position
/// of the inner tiles and permutation of the outer tiled loop.
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(b.getContext(), dim0, dim1);
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
                                                 {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
}
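// Illustrative transposition (hypothetical values): for a pack with
// inner_dims_pos = [0, 1] and inner_tiles = [8, 32], an innerPermutation of
// [1, 0] yields a clone with inner_dims_pos = [1, 0] and inner_tiles =
// [32, 8], writing into a freshly created destination of the transposed
// shape.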
/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

Speculation::Speculatability PackOp::getSpeculatability() {
  if (getPaddingValue())
    return Speculation::Speculatable;

  // Without a padding value, the op is only speculatable when all tiles and
  // tiled dims are statically known to fit.
  if (!areTilesAndTiledDimsAllConstant(*this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // The outer dims permutation is optional: treat a missing permutation as
  // equal to the identity permutation.
  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
         isIdentityPermutation(unPackOp.getOuterDimsPerm());
}

static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}
/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}
/// Returns true if `srcShape` or `destShape` differ from the shapes in
/// `packOp`, and populates both with the inferred static shapes.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(packOp.getInnerDimsPos().begin(),
                   packOp.getInnerDimsPos().end());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
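// Illustrative inference (hypothetical shapes): src tensor<?x32xf32>, dest
// tensor<7x4x8xf32>, inner_dims_pos = [1]. Dim 0 is untiled, so the static 7
// from the dest propagates into srcShape, giving [7, 32] and setting
// changeNeeded; the canonicalization below then inserts tensor.cast ops.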
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold pack(unpack(x)) -> x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Fold the optional padding_value operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
    rewriter.startOpModification(packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Insert a cast back to the original result type if needed.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
                    std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose
  // dimensions.

  // Check that we are not transposing any dimensions.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't happen in order.
    return false;
  }

  // The pack shifts the packed dims all the way inward; this is a no-op
  // (rather than a transpose) only if every dimension that bubbles outward
  // is a unit dimension.
  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}
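// Illustrative case (hypothetical shapes): a pack of tensor<30xf32> with
// inner_dims_pos = [0], inner_tiles = [32] and a padding value into
// tensor<1x32xf32> only prepends a unit dim, so it behaves like padding to
// 32 elements and isLikePad() returns true.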
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}
void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}

SmallVector<int64_t> UnPackOp::getAllOuterDims() {
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);
}

SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
  auto innerDimsPos = getInnerDimsPos();
  auto packedShape = getSourceType().getShape();
  SmallVector<int64_t> res;
  for (auto index : innerDimsPos)
    res.push_back(packedShape[index]);
  return res;
}
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
                                        Value source,
                                        ArrayRef<OpFoldResult> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr sym0, sym1;
  bindSymbols(b.getContext(), sym0, sym1);
  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
  };

  SmallVector<OpFoldResult> mixedSizes;
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
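// Illustrative destination (hypothetical shapes): for a packed source
// tensor<4x8x16xf32> with inner_dims_pos = [0] and inner_tiles = [16], the
// outer dims give sizes [4, 8]; dimMul scales dim 0 by 16, producing a
// tensor.empty of type tensor<64x8xf32>.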
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}
/// Returns true if `srcShape` or `destShape` differ from the shapes in `op`,
/// and populates both with the inferred static shapes.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  // Fold unpack(pack(x)) -> x.
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }
  // Fold unpack(destinationStyleOp(x)) -> unpack(x).
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(
        unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                               unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = rewriter.create<tensor::CastOp>(loc, newDestType,
                                             unPackOp.getDest());
    }
    Value newOp = rewriter.create<UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
        unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
  // Reject ops such as tensor.insert_slice and loop-like ops that have
  // special tensor.cast folding logic of their own.
  if (isa<InsertSliceOp>(op.getOperation()) ||
      isa<LoopLikeOpInterface>(op.getOperation()))
    return false;

  // Fail if no operand comes from a foldable tensor::CastOp.
  bool hasTensorCastOperand =
      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
        if (llvm::isa<BlockArgument>(opOperand.get()))
          return false;
        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
        return castOp && canFoldIntoConsumerOp(castOp);
      });

  return hasTensorCastOperand;
}
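// Illustrative precondition hit (hypothetical types):
//
//   %c = tensor.cast %t : tensor<4x8xf32> to tensor<?x8xf32>
//
// feeding %c into a destination-style op qualifies, because folding the cast
// away makes the consumer's operand type strictly more static.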
static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
                                         SmallVector<Type> &newResTy) {
  SmallVector<Value> newOperands;
  newOperands.reserve(op->getNumOperands());

  // Assumes that the result has dpsInits followed by nonDpsInits.
  int64_t dpsInitIdx = 0;
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
    bool fold = canFoldIntoConsumerOp(tensorCastOp);
    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
    if (op.isDpsInit(&opOperand) &&
        !llvm::isa<MemRefType>(newOperands.back().getType()))
      newResTy[dpsInitIdx++] = newOperands.back().getType();
  }
  return newOperands;
}
/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// tensor.cast has a source that is more static than the consuming op.
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp op,
                                PatternRewriter &rewriter) const override {
    if (!foldTensorCastPrecondition(op))
      return failure();

    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes;
    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
                                 .getShape()
                                 .take_back(op.getMixedTiles().size()),
                             op.getMixedTiles())) {
      int64_t shape = std::get<0>(it);
      if (shape == ShapedType::kDynamic) {
        newMixedTileSizes.push_back(std::get<1>(it));
        continue;
      }

      if (Attribute attr =
              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
        // Already a constant.
        newMixedTileSizes.push_back(std::get<1>(it));
      } else {
        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
        assert(tileSize == shape && "tile size and dim size don't match!");
        newMixedTileSizes.push_back(rewriter.getIndexAttr(shape));
      }
    }

    // Clone op.
    PackOp newOp = rewriter.create<PackOp>(
        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());

    // Replace op, casting back to the original result type if needed.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement =
        (newResult.getType() != oldResult.getType())
            ? rewriter.create<tensor::CastOp>(op->getLoc(),
                                              oldResult.getType(), newResult)
            : newResult;
    rewriter.replaceOp(op, {replacement});
    return success();
  }
};
    // Clone the op with the more static operand types. (The opening of this
    // pattern, which folds a tensor.cast into a consuming
    // DestinationStyleOpInterface op, is elided in this listing.)
    auto newOp = clone(rewriter, op, newResultTypes, newOperands);

    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
  }
};
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastPackOp, FoldTensorCastProducerOp>(getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"