#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
  return op;
if (complex::ConstantOp::isBuildableWith(value, type))
  return builder.create<complex::ConstantOp>(loc, type,
                                             llvm::cast<ArrayAttr>(value));
auto tensorType = llvm::cast<RankedTensorType>(value.getType());

if (tensorType.isDynamicDim(dim))
  return builder.createOrFold<tensor::DimOp>(loc, value, dim);

auto tensorType = llvm::cast<RankedTensorType>(value.getType());

for (int64_t i = 0; i < tensorType.getRank(); ++i)

auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
assert(tensorType && "expected tensor type");

auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();

return destOp.getTiedOpOperand(opResult)->get();

if (!tensorType.hasStaticShape()) {

for (int64_t sz : tensorType.getShape())

b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());

if (llvm::isa<TensorType>(opResult.getType())) {

result.push_back(*destination);

if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
  if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
    return rtp1.getShape() == rtp2.getShape() &&
           rtp1.getElementType() == rtp2.getElementType();
llvm::SmallBitVector droppedDims(mixedSizes.size());
int64_t shapePos = 0;

for (const auto &size : enumerate(mixedSizes)) {
  // A rank-reduced dim must be a static unit size.
  bool isStaticUnitSize =
      size.value().is<Attribute>() &&
      llvm::cast<IntegerAttr>(size.value().get<Attribute>()).getInt() == 1;

  if (shapePos == static_cast<int64_t>(reducedShape.size())) {
    // There are no more dims in the reduced shape. All remaining sizes must
    // be rank-reduced dims.
    assert(isStaticUnitSize && "expected unit dim");
    droppedDims.set(size.index());
    continue;
  }

  // Dim is preserved if the size is not a static 1.
  if (!isStaticUnitSize) {
    ++shapePos;
    continue;
  }

  // Dim is preserved if the reduced shape dim is also 1.
  if (reducedShape[shapePos] == 1) {
    ++shapePos;
    continue;
  }

  // Otherwise the dim is dropped.
  droppedDims.set(size.index());
}

assert(shapePos == static_cast<int64_t>(reducedShape.size()) &&
       "dimension mismatch");
if (inputs.size() != 1 || outputs.size() != 1)
  return false;
Type a = inputs.front(), b = outputs.front();
auto aT = dyn_cast<TensorType>(a);
auto bT = dyn_cast<TensorType>(b);

if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
  return false;

auto tensorBitcastOperand =
    tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
if (!tensorBitcastOperand)
  return failure();

auto resultType = cast<TensorType>(tensorBitcast.getType());
rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
                                       tensorBitcastOperand.getOperand());
results.add<ChainedTensorBitcast>(context);

setNameFn(getResult(), "cast");
auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
auto targetType = llvm::dyn_cast<RankedTensorType>(target);

// Requires RankedTensorType.
if (!sourceType || !targetType)
  return false;

// Requires the same element type.
if (sourceType.getElementType() != targetType.getElementType())
  return false;

// Requires the same rank.
if (sourceType.getRank() != targetType.getRank())
  return false;

// If the cast makes any dimension more static, don't fold.
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
  if (!ShapedType::isDynamic(std::get<0>(t)) &&
      ShapedType::isDynamic(std::get<1>(t)))
    return false;
}

return preservesStaticInformation(castOp.getType(),
                                  castOp.getSource().getType());

auto castOp = operand.get().getDefiningOp<tensor::CastOp>();

operand.set(castOp.getOperand());
if (inputs.size() != 1 || outputs.size() != 1)
  return false;
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<TensorType>(a);
auto bT = llvm::dyn_cast<TensorType>(b);

if (aT.getElementType() != bT.getElementType())
  return false;

int64_t rank = one.getRank();
if (rank != two.getRank())
  return {};

for (int64_t i = 0; i < rank; ++i) {
  if (one.isDynamicDim(i)) {
    join.push_back(two.getDimSize(i));
    continue;
  }
  if (two.isDynamicDim(i)) {
    join.push_back(one.getDimSize(i));
    continue;
  }
  if (one.getDimSize(i) != two.getDimSize(i))
    return {};
  join.push_back(one.getDimSize(i));
}
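// Illustrative join (hypothetical types): tensor<?x4xf32> joined with
// tensor<8x?xf32> takes the static extent from whichever side has one,
// yielding tensor<8x4xf32>; mismatched static extents make the join fail.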
auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();

if (!tensorCastOperand)
  return failure();

auto sourceType =
    llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
auto resultType = llvm::cast<TensorType>(tensorCast.getType());

auto newJoin = joinShapes(sourceType, resultType);
if (firstJoin != newJoin)
  return failure();

rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
                                    tensorCastOperand.getOperand());
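// Illustrative IR (hypothetical values): a chain such as
//   %1 = tensor.cast %0 : tensor<8x4xf32> to tensor<?x4xf32>
//   %2 = tensor.cast %1 : tensor<?x4xf32> to tensor<?x?xf32>
// canonicalizes to a single tensor.cast from %0 to tensor<?x?xf32> whenever
// the intermediate type adds no information beyond the joined shapes.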
auto extractOperand =
    tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();

// Cannot fold cast to unranked tensor.
auto rankedResultType =
    llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
if (!rankedResultType)
  return failure();

if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
    rankedResultType.getShape() ==
        llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
            .getShape())
  return failure();

auto dimMask = computeRankReductionMask(
    extractOperand.getStaticSizes(), extractOperand.getType().getShape());
size_t dimIndex = 0;
for (size_t i = 0, e = sizes.size(); i < e; i++) {
  if (dimMask && dimMask->count(i))
    continue;
  int64_t dim = rankedResultType.getShape()[dimIndex++];
  if (ShapedType::isDynamic(dim))
    continue;
  sizes[i] = rewriter.getIndexAttr(dim);
}

rewriter.replaceOpWithNewOp<ExtractSliceOp>(
    tensorCast, rankedResultType, extractOperand.getSource(),
    extractOperand.getMixedOffsets(), sizes,
    extractOperand.getMixedStrides());
results.add<ChainedTensorCast, TensorCastExtractSlice>(context);

setNameFn(getResult(), "dim");

Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
build(builder, result, source, indexValue);
std::optional<int64_t> DimOp::getConstantIndex() {

auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
if (!rankedSourceType)
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());

// Out-of-bound indices produce undefined behavior but are still valid IR.
int64_t indexVal = index.getInt();
if (indexVal < 0 || indexVal >= tensorType.getRank())
  return {};

// Fold if the shape extent along the given index is known.
if (!tensorType.isDynamicDim(index.getInt())) {
  Builder builder(getContext());
  return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
}

Operation *definingOp = getSource().getDefiningOp();

// Fold dim to the operand of tensor.generate.
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
  auto resultType =
      llvm::cast<RankedTensorType>(fromElements.getResult().getType());
  // The case where the type encodes the size of the dimension is handled
  // above.
  assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));

  // Find the operand of the generate op that corresponds to this index.
  auto dynExtents = fromElements.getDynamicExtents().begin();
  for (auto dim : resultType.getShape().take_front(index.getInt()))
    if (ShapedType::isDynamic(dim))
      dynExtents++;

  return Value{*dynExtents};
}

unsigned unsignedIndex = index.getValue().getZExtValue();

if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
  // Fold only for non-rank-reduced ops.
  if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
      sliceOp.isDynamicSize(unsignedIndex)) {
    return {sliceOp.getDynamicSize(unsignedIndex)};
  }
}
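// Illustrative fold (hypothetical IR): with %t : tensor<4x8xf32>,
//   %d = tensor.dim %t, %c1
// folds to the index constant 8; only dynamic extents reach the
// defining-op cases above.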
auto castOp = dimOp.getSource().getDefiningOp<CastOp>();

Value newSource = castOp.getOperand();

auto source = dimOp.getSource();
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();

auto resultIndex = source.cast<OpResult>().getResultNumber();
auto initOperand = destOp.getDpsInitOperand(resultIndex);

rewriter.updateRootInPlace(
    dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });

results.add<DimOfCastOp, DimOfDestStyleOp>(context);
assert(all_of(staticShape,
              [](int64_t sz) { return !ShapedType::isDynamic(sz); }) &&
       "expected only static sizes");
build(builder, result, staticShape, elementType, ValueRange{}, encoding);

build(builder, result, tensorType, dynamicSizes);

build(builder, result, staticShape, elementType, dynamicSizes, encoding);

if (getType().getNumDynamicDims() !=
    static_cast<int64_t>(getDynamicSizes().size()))
  return emitOpError("incorrect number of dynamic sizes, has ")
         << getDynamicSizes().size() << ", expected "
         << getType().getNumDynamicDims();

for (int64_t i = 0; i < getType().getRank(); ++i) {
  if (getType().isDynamicDim(i)) {

    reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
Value EmptyOp::getDynamicSize(unsigned idx) {
  assert(getType().isDynamicDim(idx) && "expected dynamic dim");
  unsigned ctr = 0;
  for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
    if (getType().isDynamicDim(i))
      ++ctr;
  return getDynamicSizes()[ctr];
}

for (int64_t i = 0; i < getType().getRank(); ++i) {
  if (getType().isDynamicDim(i)) {

    result.push_back(b.getIndexAttr(getType().getShape()[i]));
SmallVector<int64_t> staticShape(op.getType().getShape().begin(),
                                 op.getType().getShape().end());

bool changedType = false;
for (int64_t i = 0; i < op.getType().getRank(); ++i) {
  if (op.getType().isDynamicDim(i)) {
    Value dynamicSize = op.getDynamicSizes()[ctr++];
    std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
    if (cst.has_value()) {
      staticShape[i] = *cst;
      changedType = true;
    } else {
      dynamicSizes.push_back(dynamicSize);
    }
  }
}

auto tensorType = RankedTensorType::get(
    staticShape, op.getType().getElementType(), op.getType().getEncoding());
auto newOp =
    rewriter.create<EmptyOp>(op.getLoc(), tensorType, dynamicSizes);

std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
if (!emptyTensorOp || !maybeConstantIndex)
  return failure();
if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
  return failure();
rewriter.replaceOp(dimOp,
                   emptyTensorOp.getDynamicSize(*maybeConstantIndex));
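// Net effect (illustrative, hypothetical IR): tensor.dim %e, %c0 where
//   %e = tensor.empty(%sz) : tensor<?xf32>
// is replaced directly by %sz, since the empty op records its dynamic sizes.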
auto producer = castOp.getSource().getDefiningOp<EmptyOp>();

auto resultType =
    llvm::cast<RankedTensorType>(castOp->getResult(0).getType());

SmallVector<OpFoldResult> newMixedSizes;
newMixedSizes.reserve(currMixedSizes.size());
assert(resultShape.size() == currMixedSizes.size() &&
       "mismatch in result shape and sizes of empty op");
for (auto it : llvm::zip(resultShape, currMixedSizes)) {
  int64_t newDim = std::get<0>(it);
  OpFoldResult currDim = std::get<1>(it);
  // Case 1: The empty tensor dim is static. Check that the tensor cast
  // result dim matches.
  if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
    if (ShapedType::isDynamic(newDim) ||
        newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
      return rewriter.notifyMatchFailure(
          producer, "mismatch in static value of shape of empty tensor "
                    "result and cast result");
    }
    newMixedSizes.push_back(attr);
    continue;
  }

  // Case 2: The tensor cast shape is static, but the empty tensor result
  // shape is dynamic.
  if (!ShapedType::isDynamic(newDim)) {
    newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
    continue;
  }

  // Case 3: Both shapes are dynamic; keep the existing mixed size.
  newMixedSizes.push_back(currDim);
}

rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
                                     resultType.getElementType());

results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
            ReplaceEmptyTensorStaticShapeDims>(context);
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();

if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
  return failure();
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
    extract, tensorCast.getSource(), extract.getIndices());

void ExtractOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted");
}

auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
  return emitOpError("incorrect number of indices for extract_element");
// If this is a splat elements attribute, simply return the value; all
// elements of a splat attribute are the same.
if (Attribute tensor = adaptor.getTensor())
  if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
    return splatTensor.getSplatValue<Attribute>();

// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : adaptor.getIndices()) {
  if (!indice || !llvm::isa<IntegerAttr>(indice))
    return {};
  indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
}

// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
  auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
  auto rank = tensorType.getRank();
  assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
         "rank mismatch");
  int flatIndex = 0;
  int stride = 1;
  for (int i = rank - 1; i >= 0; --i) {
    flatIndex += indices[i] * stride;
    stride *= tensorType.getDimSize(i);
  }
  // Prevent out-of-bounds accesses. This can happen in invalid code that
  // will never be executed.
  if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
      flatIndex < 0)
    return {};
  return fromElementsOp.getElements()[flatIndex];
}

// If this is an elements attribute, query the value at the given indices.
if (Attribute tensor = adaptor.getTensor()) {
  auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
  if (elementsAttr && elementsAttr.isValidIndex(indices))
    return elementsAttr.getValues<Attribute>()[indices];
}
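// Flat-index arithmetic, worked through (values hypothetical): for a
// tensor<2x3> and indices [1, 2], the loop visits i = 1 then i = 0:
//   flatIndex = 2 * 1 = 2, stride becomes 3; then flatIndex += 1 * 3 = 5.
// Element 5 of the from_elements operand list is returned.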
results.add<ExtractFromTensorCast>(context);

void FromElementsOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "from_elements");
}

assert(!elements.empty() && "expected at least one element");
Type resultType = RankedTensorType::get(
    {static_cast<int64_t>(elements.size())}, elements.front().getType());
build(builder, result, resultType, elements);

OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
  if (!llvm::is_contained(adaptor.getElements(), nullptr))
    return DenseElementsAttr::get(getType(), adaptor.getElements());
struct ExtractElementFromIndexCast
    : public OpRewritePattern<tensor::ExtractOp> {

  auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();

  auto newExtract = rewriter.create<tensor::ExtractOp>(
      loc, elementTy, indexCast.getIn(), extract.getIndices());

results.add<ExtractElementFromIndexCast>(context);

void GatherOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "gather");
}
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
                                           RankedTensorType indicesType,
                                           ArrayRef<int64_t> gatherDims,
                                           bool rankReduced) {
  SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
  resultShape.reserve(resultShape.size() + sourceType.getRank());
  for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
    if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
      if (!rankReduced)
        resultShape.push_back(1);
      continue;
    }
    resultShape.push_back(sourceType.getDimSize(idx));
  }
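// Shape-inference example (hypothetical types): with source tensor<4x5x6xf32>,
// indices tensor<1x2x1xindex>, and gather_dims = [1], the leading indices
// batch dims [1, 2] are kept, the gathered dim becomes 1, and the rest pass
// through: tensor<1x2x4x1x6xf32>, or tensor<1x2x4x6xf32> when rank-reduced.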
static LogicalResult verifyGatherOrScatterDims(Operation *op,
                                               ArrayRef<int64_t> dims,
                                               int64_t rank,
                                               StringRef gatherOrScatter,
                                               StringRef sourceOrDest) {
  if (dims.empty())
    return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";

  int64_t numGatherDims = dims.size();
  if (numGatherDims > rank)
    return op->emitOpError(gatherOrScatter)
           << "_dims overflow " << sourceOrDest << " rank";
  for (int64_t val : dims) {
    if (val < 0)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be non-negative";
    if (val >= rank)
      return op->emitOpError(gatherOrScatter)
             << "_dims value must be smaller than " << sourceOrDest << " rank";
  }
  for (int64_t i = 1; i < numGatherDims; ++i) {
    if (dims[i - 1] >= dims[i])
      return op->emitOpError(gatherOrScatter)
             << "_dims values must be strictly increasing";
  }
int64_t sourceRank = getSourceType().getRank();
ArrayRef<int64_t> gatherDims = getGatherDims();
if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
                                     "gather", "source")))
  return failure();

RankedTensorType expectedResultType = GatherOp::inferResultType(
    getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
    getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
if (getResultType() != expectedResultType &&
    getResultType() != expectedRankReducedResultType) {
  return emitOpError("result type "
                     "mismatch: expected ")
         << expectedResultType << " or its rank-reduced variant "
         << expectedRankReducedResultType << " (got: " << getResultType()
         << ")";
}
void InsertOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted");
}

auto destType = llvm::cast<RankedTensorType>(getDest().getType());
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
  return emitOpError("incorrect number of indices");

if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
  if (scalar == splatDest.getSplatValue<Attribute>())
    return dest;
void GenerateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "generated");
}

for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
  if (getType().isDynamicDim(dim)) {
    reifiedReturnShapes[0][dim] = getOperand(idx++);
  } else {
    reifiedReturnShapes[0][dim] =
        builder.getIndexAttr(getType().getDimSize(dim));
  }
}
auto operandsIt = dynamicExtents.begin();
for (int64_t dim : resultType.getShape()) {
  // Copy static dimensions.
  if (!ShapedType::isDynamic(dim)) {
    newShape.push_back(dim);
    continue;
  }
  // If the operand is not a known constant, keep the dimension dynamic.
  APInt index;
  if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
    newShape.push_back(ShapedType::kDynamic);
    newOperands.push_back(*operandsIt++);
    continue;
  }
  newShape.push_back(index.getSExtValue());
RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
if (getNumOperands() != resultType.getNumDynamicDims())
  return emitError("must have as many index operands as dynamic extents "
                   "in the result type");

for (int64_t newdim : newShape) {
  if (newdim < 0 && !ShapedType::isDynamic(newdim))
    return emitError("tensor dimensions must be non-negative");
}

RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());

if (!llvm::all_of(getBody().getArgumentTypes(),
                  [](Type ty) { return ty.isIndex(); }))
  return emitError("all body arguments must be index");
if (getBody().getNumArguments() != resultTy.getRank())
  return emitError("must have one body argument per input dimension");

auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());

if (yieldOp.getValue().getType() != resultTy.getElementType())
  return emitOpError(
      "body must be terminated with a `yield` operation of the tensor "
      "element type");

void GenerateOp::build(

build(b, result, resultTy, dynamicExtents);
auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();

b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);

LogicalResult matchAndRewrite(GenerateOp tensorFromElements,
                              PatternRewriter &rewriter) const final {
  auto resultType =
      llvm::cast<RankedTensorType>(tensorFromElements.getResult().getType());

  if (resultType.hasStaticShape())
    return failure();

  Operation::operand_range dynamicExtents =
      tensorFromElements.getDynamicExtents();

  for (int64_t newdim : newShape) {
    if (newdim < 0 && !ShapedType::isDynamic(newdim))
      return failure();
  }

  if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
    return failure();

  auto loc = tensorFromElements.getLoc();
  auto newOp = rewriter.create<GenerateOp>(
      loc, RankedTensorType::get(newShape, resultType.getElementType()),
      newOperands);

  rewriter.inlineRegionBefore(tensorFromElements.getBody(), newOp.getBody(),
                              newOp.getBody().begin());
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {

  auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();

  Block *body = &tensorFromElements.getBody().front();

  rewriter.clone(op, mapping);

results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "rank");
}

auto type = getOperand().getType();
auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
  return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
return IntegerAttr();
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}

int64_t numElements = 1;
for (auto dim : type.getShape())
  numElements *= dim;
TensorType operandType = llvm::cast<TensorType>(getSource().getType());
TensorType resultType = llvm::cast<TensorType>(getResult().getType());

if (operandType.getElementType() != resultType.getElementType())
  return emitOpError("element types of source and destination tensor "
                     "types should be the same");

int64_t shapeSize =
    llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);

if (resultRankedType) {
  if (operandRankedType && resultRankedType.hasStaticShape() &&
      operandRankedType.hasStaticShape()) {
    if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
      return emitOpError("source and destination tensor should have the "
                         "same number of elements");
  }
  if (ShapedType::isDynamic(shapeSize))
    return emitOpError("cannot use shape operand with dynamic length to "
                       "reshape to statically-ranked tensor type");
  if (shapeSize != resultRankedType.getRank())
    return emitOpError(
        "length of shape operand differs from the result's tensor rank");
}
void CollapseShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "collapsed");
}

void ExpandShapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "expanded");
}

int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
  assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
         "invalid resultDim");
  for (const auto &it : llvm::enumerate(getReassociationIndices()))
    if (llvm::is_contained(it.value(), resultDim))
      return it.index();
  llvm_unreachable("could not find reassociation group");
}

return convertReassociationIndicesToExprs(getContext(),
                                          getReassociationIndices());

return convertReassociationIndicesToExprs(getContext(),
                                          getReassociationIndices());
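// Reassociation example (hypothetical): with reassociation [[0, 1], [2]],
// result dims 0 and 1 both map back to source dim 0, and result dim 2 maps
// to source dim 1; getCorrespondingSourceDim(1) returns 0.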
RankedTensorType CollapseShapeOp::inferCollapsedType(
    RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
  return inferCollapsedType(
      type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
                type.getContext(), reassociation)));
}

RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
                                    ArrayRef<AffineMap> reassociation) {
  auto shape = type.getShape();
  SmallVector<int64_t, 4> newShape;
  newShape.reserve(reassociation.size());

  unsigned currentDim = 0;
  for (AffineMap m : reassociation) {
    unsigned dim = m.getNumResults();
    auto band = shape.slice(currentDim, dim);
    int64_t size = 1;
    if (llvm::is_contained(band, ShapedType::kDynamic))
      size = ShapedType::kDynamic;
    else
      for (unsigned d = 0; d < dim; ++d)
        size *= shape[currentDim + d];
    newShape.push_back(size);
    currentDim += dim;
  }

auto resultType = inferCollapsedType(
    llvm::cast<RankedTensorType>(src.getType()),
    getSymbolLessAffineMaps(
        convertReassociationIndicesToExprs(b.getContext(), reassociation)));
build(b, result, resultType, src, attrs);
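// Collapsed-type example (hypothetical): collapsing tensor<4x5x8xf32> with
// reassociation [[0, 1], [2]] multiplies the first band to tensor<20x8xf32>;
// any dynamic extent in a band makes that collapsed dim dynamic, e.g.
// tensor<4x?x8xf32> collapses to tensor<?x8xf32>.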
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
                                        TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
                                           RankedTensorType expandedType,
                                           RankedTensorType collapsedType) {

  auto maps = op.getReassociationMaps();
  RankedTensorType expectedType =
      CollapseShapeOp::inferCollapsedType(expandedType, maps);
  if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
    return op.emitOpError("expected collapsed type to be ")
           << expectedType << ", but got " << collapsedType;
auto srcType = getSrcType();
auto resultType = getResultType();
if (srcType.getRank() >= resultType.getRank())
  return emitOpError("expected rank expansion, but found source rank ")
         << srcType.getRank() << " >= result rank " << resultType.getRank();

auto srcType = getSrcType();
auto resultType = getResultType();
if (srcType.getRank() <= resultType.getRank())
  return emitOpError("expected rank reduction, but found source rank ")
         << srcType.getRank() << " <= result rank " << resultType.getRank();
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {

    auto newAttr = DenseElementsAttr::getFromRawBuffer(
        reshapeOp.getResultType(), attr.getRawData());

template <typename TensorReshapeOp>
struct FoldReshapeWithSplat : OpRewritePattern<TensorReshapeOp> {

    auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();

    rewriter.replaceOpWithNewOp<tensor::SplatOp>(
        reshapeOp, reshapeOp.getResultType(), splatOp.getInput());

template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {

    auto fromElements =
        reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();

    auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());

    if (!shapedTy.hasStaticShape())
      return failure();

    rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
                                                fromElements.getElements());
LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                              PatternRewriter &rewriter) const override {
  auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
  if (!tensor::canFoldIntoConsumerOp(castOp))
    return failure();

  RankedTensorType srcType =
      llvm::cast<RankedTensorType>(castOp.getSource().getType());
  RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
      srcType, collapseShapeOp.getReassociationMaps());

  if (newResultType == collapseShapeOp.getResultType()) {
    rewriter.updateRootInPlace(collapseShapeOp, [&]() {
      collapseShapeOp.getSrcMutable().assign(castOp.getSource());
    });
  } else {
    auto newOp = rewriter.create<CollapseShapeOp>(
        collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
        collapseShapeOp.getReassociation());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        collapseShapeOp, collapseShapeOp.getResultType(), newOp);
auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
if (!expandShapeOp)
  return failure();

// Only constant dimension values are supported.
std::optional<int64_t> dim = dimOp.getConstantIndex();
if (!dim.has_value())
  return failure();

// Skip static dims. These are folded to constant ops.
RankedTensorType resultType = expandShapeOp.getResultType();
if (!resultType.isDynamicDim(*dim))
  return failure();

// Find the reassociation group that contains this result dimension.
int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);

// `dim` must be the only dynamic dimension in its group.
int64_t product = 1;
ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
for (int64_t d : grp) {
  if (d != dim) {
    assert(!resultType.isDynamicDim(d) && "expected static dim");
    product *= resultType.getDimSize(d);
  }
}

// result dim size = src dim size / (product of other dims in reassoc group)
Value srcDimSz =
    rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
AffineExpr expr;
bindSymbols(dimOp.getContext(), expr);
rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
    dimOp, expr.floorDiv(product), srcDimSz);
auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
if (!collapseShapeOp)
  return failure();

// Only constant dimension values are supported.
std::optional<int64_t> dim = dimOp.getConstantIndex();
if (!dim.has_value())
  return failure();

// Skip static dims. These are folded to constant ops.
RankedTensorType resultType = collapseShapeOp.getResultType();
if (!resultType.isDynamicDim(*dim))
  return failure();

// Get the reassociation group of the result dimension.
ReassociationIndices group =
    collapseShapeOp.getReassociationIndices()[*dim];

SmallVector<Value> srcDimSizes;
for (const auto &it : llvm::enumerate(group)) {
  srcDimSizes.push_back(rewriter.create<DimOp>(
      dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
            FoldReshapeWithConstant<ExpandShapeOp>,
            FoldReshapeWithSplat<ExpandShapeOp>,
            FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
            FoldDimOfCollapseShape>(context);

            FoldReshapeWithConstant<CollapseShapeOp>,
            FoldReshapeWithSplat<CollapseShapeOp>,
            FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
    context);

OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
                                                       adaptor.getOperands());
}

OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
                                                       adaptor.getOperands());
}
void ExtractSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "extracted_slice");
}

RankedTensorType ExtractSliceOp::inferResultType(

  assert(static_cast<int64_t>(staticSizes.size()) ==
             sourceTensorType.getRank() &&
         "unexpected staticSizes not equal to rank of source");

RankedTensorType ExtractSliceOp::inferResultType(

  return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
                                         staticSizes, staticStrides);
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
    ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
    ArrayRef<int64_t> strides) {
  // Type inferred in the absence of rank-reducing behavior.
  auto inferredType = llvm::cast<RankedTensorType>(
      inferResultType(sourceRankedTensorType, offsets, sizes, strides));
  int rankDiff = inferredType.getRank() - desiredResultRank;
  if (rankDiff > 0) {
    auto shape = inferredType.getShape();
    llvm::SmallBitVector dimsToProject =
        getPositionsOfShapeOne(rankDiff, shape);
    SmallVector<int64_t> projectedShape;
    // Best-effort rank reduction: drop 1s in order.
    for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
      if (!dimsToProject.test(pos))
        projectedShape.push_back(shape[pos]);
    inferredType =
        RankedTensorType::get(projectedShape, inferredType.getElementType());
  }
  return inferredType;
}

RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
    unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,

  return ExtractSliceOp::inferCanonicalRankReducedResultType(
      desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
    RankedTensorType resultType, Value source,

auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());

resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
    sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));

build(b, result, resultType, source, dynamicOffsets, dynamicSizes,

build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);

    RankedTensorType resultType, Value source,

build(b, result, resultType, source, offsetValues, sizeValues, strideValues);

build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
                                          Operation *op,
                                          RankedTensorType expectedType) {
  switch (result) {
  case SliceVerificationResult::Success:
    return success();
  case SliceVerificationResult::RankTooLarge:
    return op->emitError("expected rank to be smaller or equal to ")
           << "the other rank. ";
  case SliceVerificationResult::SizeMismatch:
    return op->emitError("expected type to be ")
           << expectedType << " or a rank-reduced version. (size mismatch) ";
  case SliceVerificationResult::ElemTypeMismatch:
    return op->emitError("expected element type to be ")
           << expectedType.getElementType();
  default:
    llvm_unreachable("unexpected extract_slice op verification result");
  }
}

RankedTensorType expectedType = ExtractSliceOp::inferResultType(
    getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
assert(sourceTensorType && "not a ranked tensor type");
auto sourceShape = sourceTensorType.getShape();
if (sourceShape.equals(desiredShape))
  return value;
auto maybeRankReductionMask =
    computeRankReductionMask(sourceShape, desiredShape);
if (!maybeRankReductionMask)
  return failure();

reifiedReturnShapes.resize(1);
reifiedReturnShapes[0].reserve(getType().getRank());
SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
llvm::SmallBitVector droppedDims = getDroppedDims();
for (const auto &size : enumerate(mixedSizes)) {
  if (droppedDims.test(size.index()))
    continue;
  reifiedReturnShapes[0].push_back(size.value());
}
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {

  if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
        return matchPattern(operand, matchConstantIndex());
      }))
    return failure();

  auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();

  Value newResult = rewriter.create<ExtractSliceOp>(
      loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
  if (newResult.getType() != sliceOp.getType())
    newResult = rewriter.create<CastOp>(loc, sliceOp.getType(), newResult);
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
                          ArrayRef<int64_t> strides,
                          llvm::SmallVectorImpl<ElemTy> *outValues) {
  assert(offsets.size() == sizes.size());
  assert(offsets.size() == strides.size());
  if (offsets.empty())
    return;

  int64_t offset = offsets.front();
  int64_t size = sizes.front();
  int64_t stride = strides.front();
  if (offsets.size() == 1) {
    for (int64_t i = 0; i < size; ++i, offset += stride)
      outValues->push_back(*(values + offset));
    return;
  }

  for (int64_t i = 0; i < size; ++i, offset += stride) {
    auto begin = values + offset * counts.front();
    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
                                  offsets.drop_front(), sizes.drop_front(),
                                  strides.drop_front(), outValues);
  }
}
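// Recursion example (hypothetical data): for a 4x6 source, counts = {6, 1}.
// With offsets = {1, 2}, sizes = {2, 3}, strides = {1, 1}, the outer level
// positions the iterator at rows 1 and 2 (offset * 6 elements in), and the
// inner level copies columns 2..4 of each row into outValues.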
class ConstantOpExtractSliceFolder final
    : public OpRewritePattern<ExtractSliceOp> {

  ConstantOpExtractSliceFolder(MLIRContext *context,
                               ControlConstantExtractSliceFusionFn controlFn)
      : OpRewritePattern<ExtractSliceOp>(context),
        controlFn(std::move(controlFn)) {}

  auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
  auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
  if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
    return failure();

  int64_t count = sourceType.getNumElements();

  // Check if there are any dynamic parts, which are not supported.
  auto offsets = op.getStaticOffsets();
  if (llvm::is_contained(offsets, ShapedType::kDynamic))
    return failure();
  auto sizes = op.getStaticSizes();
  if (llvm::is_contained(sizes, ShapedType::kDynamic))
    return failure();
  auto strides = op.getStaticStrides();
  if (llvm::is_contained(strides, ShapedType::kDynamic))
    return failure();

  // Compute the stride for each dimension.
  SmallVector<int64_t> counts;
  counts.reserve(shape.size());
  for (int64_t v : shape) {
    count = count / v;
    counts.push_back(count);
  }

  // New attribute constructed from the sliced values.
  DenseElementsAttr newAttr;

  if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
    SmallVector<APInt> outValues;
    outValues.reserve(sourceType.getNumElements());
    sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
        elems.begin(), counts, offsets, sizes, strides, &outValues);
  } else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
    SmallVector<APFloat> outValues;
    outValues.reserve(sourceType.getNumElements());
    sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
        elems.begin(), counts, offsets, sizes, strides, &outValues);
patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);

return ExtractSliceOp::inferCanonicalRankReducedResultType(
    op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
    mixedStrides);

                ExtractSliceOp newOp) {
  Value replacement = newOp.getResult();
  if (replacement.getType() != op.getType())
    replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
                                                  replacement);

      ExtractSliceOpCastFolder>(context);
    ShapedType shapedType) {

auto shape = shapedType.getShape();
for (auto it : llvm::zip(op.getMixedSizes(), shape))

auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();

if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
    insertOp.isSameAs(extractOp, isSame))
  return insertOp.getSource();
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
  if (auto splat =
          llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) {
    auto resultType = llvm::cast<ShapedType>(getResult().getType());
    if (resultType.hasStaticShape())
      return splat.resizeSplat(resultType);
  }
  if (getSourceType() == getType() &&
      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
    return this->getSource();

auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
unsigned rank = rankedTensorType.getRank();

return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
                                              offsets, sizes, strides);
void InsertSliceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "inserted_slice");
}

build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,

build(b, result, source, dest, offsets, sizes, strides, attrs);

build(b, result, source, dest, offsetValues, sizeValues, strideValues);
static SliceVerificationResult
verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType,
                    ArrayRef<int64_t> staticOffsets,
                    ArrayRef<int64_t> staticSizes,
                    ArrayRef<int64_t> staticStrides,
                    RankedTensorType *expectedType = nullptr) {
  // insert_slice is the inverse of extract_slice; use the same type
  // inference.
  RankedTensorType expected = ExtractSliceOp::inferResultType(
      dstType, staticOffsets, staticSizes, staticStrides);
  if (expectedType)
    *expectedType = expected;
  return isRankReducedType(expected, srcType);
}

RankedTensorType expectedType;
SliceVerificationResult result =
    verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
                        getStaticSizes(), getStaticStrides(), &expectedType);

auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();

if (!prevInsertOp ||
    prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
    !prevInsertOp.isSameAs(insertOp, isSame))
  return failure();

insertOp.getDestMutable().assign(prevInsertOp.getDest());
auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();

if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
    !extractOp.isSameAs(insertOp, isSame))
  return nullptr;

return extractOp.getSource();

if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
    getSourceType() == getType() &&
    succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
  return this->getSource();
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
    : public OpRewritePattern<InsertOpTy> {

  // Create the new op in canonical form.
  auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
      insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
      mixedOffsets, mixedSizes, mixedStrides);
  Value toInsert = insertSliceOp.getSource();
  if (sourceType != insertSliceOp.getSourceType()) {
    OpBuilder::InsertionGuard g(rewriter);
    // The only difference between InsertSliceOp and ParallelInsertSliceOp
    // is that the insertion point is just before the ParallelCombiningOp in
    // the parallel case.
    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
      rewriter.setInsertionPoint(insertSliceOp->getParentOp());
    toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                               sourceType, toInsert);
  }
  rewriter.replaceOpWithNewOp<InsertOpTy>(
      insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
      mixedSizes, mixedStrides);
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {

  if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
        return matchPattern(operand, matchConstantIndex());
      }))
    return failure();

  auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
    auto castOp = v.getDefiningOp<tensor::CastOp>();
    if (!castOp || !canFoldIntoConsumerOp(castOp))
      return std::nullopt;
    return castOp.getSource();
  };
  std::optional<Value> sourceCastSource =
      getSourceOfCastOp(insertSliceOp.getSource());
  std::optional<Value> destCastSource =
      getSourceOfCastOp(insertSliceOp.getDest());
  if (!sourceCastSource && !destCastSource)
    return failure();

  auto src =
      (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
  auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
  auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
  auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
  if (!srcType || !dstType)
    return failure();
  if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
                          insertSliceOp.getStaticSizes(),
                          insertSliceOp.getStaticStrides()) !=
      SliceVerificationResult::Success)
    return failure();

  Operation *replacement = rewriter.create<InsertOpTy>(
      insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
      insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());

  bool isParallelInsert =
      std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
  if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
    replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
                                                  insertSliceOp.getDestType(),
                                                  replacement->getResult(0));
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
    : public OpRewritePattern<InsertOpTy> {

  RankedTensorType srcType = insertSliceOp.getSourceType();
  if (srcType.getRank() != insertSliceOp.getDestType().getRank())
    return failure();
  SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
                                   srcType.getShape().end());
  for (int64_t i = 0; i < srcType.getRank(); ++i) {
    if (std::optional<int64_t> constInt =
            getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
      newSrcShape[i] = *constInt;
  }

  RankedTensorType newSrcType =
      RankedTensorType::get(newSrcShape, srcType.getElementType());
  if (srcType == newSrcType ||
      !preservesStaticInformation(srcType, newSrcType) ||
      !tensor::CastOp::areCastCompatible(srcType, newSrcType))
    return failure();

  if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
    rewriter.setInsertionPoint(insertSliceOp->getParentOp());
  Value cast = rewriter.create<tensor::CastOp>(
      insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
  rewriter.replaceOpWithNewOp<InsertOpTy>(
      insertSliceOp, cast, insertSliceOp.getDest(),
      insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
      insertSliceOp.getMixedStrides());

results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
            InsertSliceOpCastFolder<InsertSliceOp>,
            InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
unsigned rank = rankedTensorType.getRank();

return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
                                             sizes, strides);

setNameFn(getResult(), "padded");
    Type typeToInfer, Type typeToInferFrom) {}

static ParseResult
parseInferType(OpAsmParser &parser,
               std::optional<OpAsmParser::UnresolvedOperand> optOperand,
               Type &typeToInfer, Type typeToInferFrom) {
  if (optOperand)
    typeToInfer = typeToInferFrom;
  return success();
}
auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
auto expectedType =
    PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
if (!expectedType) {
  return emitError("failed to infer expectedType from sourceType ")
         << sourceType << ", specified resultType is " << resultType;
}
if (resultType.getRank() != expectedType.getRank()) {
  return emitError("specified type ")
         << resultType << " does not match the inferred type "
         << expectedType;
}
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
  if (resultType.getDimSize(i) == expectedType.getDimSize(i))
    continue;
  if (expectedType.isDynamicDim(i))
    continue;
  return emitError("specified type ")
         << resultType << " does not match the inferred type "
         << expectedType;
}
2582 auto ®ion = getRegion();
2583 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
2586 return emitError(
"expected the block to have ") << rank <<
" arguments";
2590 if (!en.value().isIndex())
2591 return emitOpError(
"expected block argument ")
2592 << (en.index() + 1) <<
" to be an index";
2597 if (yieldOp.getValue().getType() !=
2599 return emitOpError(
"expected yield type to match shape element type");
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
                                        ArrayRef<int64_t> staticLow,
                                        ArrayRef<int64_t> staticHigh,
                                        ArrayRef<int64_t> resultShape) {
  unsigned rank = sourceType.getRank();
  if (staticLow.size() != rank)
    return RankedTensorType();
  if (staticHigh.size() != rank)
    return RankedTensorType();
  if (!(resultShape.empty() || resultShape.size() == rank))
    return RankedTensorType();

  SmallVector<int64_t, 4> inferredShape;
  for (auto i : llvm::seq<unsigned>(0, rank)) {
    if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
        staticHigh[i] == ShapedType::kDynamic) {
      inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
                                                  : resultShape[i]);
    } else {
      int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
      assert((resultShape.empty() || size == resultShape[i] ||
              resultShape[i] == ShapedType::kDynamic) &&
             "mismatch between inferred shape and result shape");
      inferredShape.push_back(size);
    }
  }
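// Inference example (hypothetical): source tensor<4x?xf32> with
// staticLow = [0, 1] and staticHigh = [2, 1] yields tensor<6x?xf32>:
// dim 0 becomes 4 + 0 + 2 = 6, and the dynamic dim stays dynamic.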
auto sourceType = llvm::cast<RankedTensorType>(source.getType());

resultType = inferResultType(sourceType, staticLow, staticHigh);
build(b, result, resultType, source, low, high,

auto sourceType = llvm::cast<RankedTensorType>(source.getType());
unsigned rank = sourceType.getRank();

build(b, result, resultType, source, staticVector, staticVector, low, high,

auto sourceType = llvm::cast<RankedTensorType>(source.getType());

resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);

assert(llvm::isa<RankedTensorType>(resultType));
build(b, result, resultType, source, dynamicLow, dynamicHigh,

build(b, result, resultType, source, low, high, nofold, attrs);

int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();

b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
llvm::SmallBitVector PadOp::getPaddedDims() {
  llvm::SmallBitVector paddedDims(getSourceType().getRank());
  auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
    for (const auto &en : enumerate(paddingWidths))
      if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
        paddedDims.set(en.index());
  };
  extractPaddedDims(getMixedLowPad());
  extractPaddedDims(getMixedHighPad());
  return paddedDims;
}
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
  return failure();
if (padTensorOp.getNofold())
  return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(
    padTensorOp, padTensorOp.getResult().getType(),
    padTensorOp.getSource());

auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
if (!tensor::canFoldIntoConsumerOp(castOp))
  return failure();

auto newResultType = PadOp::inferResultType(
    llvm::cast<RankedTensorType>(castOp.getSource().getType()),
    padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
    padTensorOp.getResultType().getShape());

if (newResultType == padTensorOp.getResultType()) {
  rewriter.updateRootInPlace(padTensorOp, [&]() {
    padTensorOp.getSourceMutable().assign(castOp.getSource());
  });
} else {
  auto newOp = rewriter.create<PadOp>(
      padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
      padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
      padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
      /*attrs=*/ArrayRef<NamedAttribute>{});
  IRMapping mapper;
  padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);

  rewriter.replaceOpWithNewOp<tensor::CastOp>(
      padTensorOp, padTensorOp.getResultType(), newOp);
}
if (!padTensorOp.getResult().hasOneUse())
  return failure();
auto tensorCastOp =
    dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
if (!tensorCastOp)
  return failure();
if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
                                        tensorCastOp.getDest().getType()))
  return failure();

auto replacementOp = rewriter.create<PadOp>(
    padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
    padTensorOp.getSource(), padTensorOp.getStaticLow(),
    padTensorOp.getStaticHigh(), padTensorOp.getLow(),
    padTensorOp.getHigh(), padTensorOp.getNofold(),
    /*attrs=*/ArrayRef<NamedAttribute>{});
replacementOp.getRegion().takeBody(padTensorOp.getRegion());

rewriter.replaceOp(padTensorOp, replacementOp.getResult());
rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
if (!innerSliceOp)
  return failure();
auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
if (!outerPadOp || outerPadOp.getNofold())
  return failure();
auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
if (!outerSliceOp)
  return failure();

// 1) Fail if the chain is rank-reducing.
int64_t rank = padOp.getSourceType().getRank();
if (outerSliceOp.getSourceType().getRank() != rank) {
  return rewriter.notifyMatchFailure(padOp,
                                     "cannot fold rank-reducing chain");
}

// 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
  return rewriter.notifyMatchFailure(
      padOp, "cannot fold non-unit stride ExtractSliceOps");
}

// 3) Fail if the tensor::PadOps have non-zero low padding.
if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
  return rewriter.notifyMatchFailure(padOp,
                                     "cannot fold PadOps with low padding");
}

// 4) Fail if the tensor::PadOps padding values do not match.
Attribute innerAttr, outerAttr;
Value innerValue = padOp.getConstantPaddingValue();
Value outerValue = outerPadOp.getConstantPaddingValue();
if (!innerValue || !outerValue ||
    !matchPattern(innerValue, m_Constant(&innerAttr)) ||
    !matchPattern(outerValue, m_Constant(&outerAttr)) ||
    innerAttr != outerAttr) {
  return rewriter.notifyMatchFailure(
      padOp, "cannot fold PadOps with different padding values");
}

// 5) Fail if a dimension is padded by both tensor::PadOps.
llvm::SmallBitVector innerDims = padOp.getPaddedDims();
llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
if (innerDims.anyCommon(outerDims)) {
  return rewriter.notifyMatchFailure(
      padOp, "cannot fold PadOps with common padding dimensions");
}

// 6) Combine the offsets of the two tensor::ExtractSliceOps.
SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
for (auto &en : enumerate(newOffsets)) {
  OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
  OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
  if (!innerDims.test(en.index()) &&
      (getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
    en.value() = outerOffset;
    continue;
  }
  if (!outerDims.test(en.index()) &&
      (getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
    en.value() = innerOffset;
    continue;
  }
  return rewriter.notifyMatchFailure(
      padOp, "cannot find zero-offset and zero-padding pair");
}

// 7) Combine the sizes of the two tensor::ExtractSliceOps.
SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
for (auto &en : enumerate(newSizes)) {
  if (!outerDims.test(en.index()))
    continue;
  OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
  int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
  assert(!ShapedType::isDynamic(sourceSize) &&
         "expected padded dimension to have a static size");
  if (getConstantIntValue(sliceSize) != sourceSize) {
    return rewriter.notifyMatchFailure(
        padOp, "cannot fold since the inner ExtractSliceOp size does not "
               "match the size of the outer padding");
  }
  en.value() = outerSliceOp.getMixedSizes()[en.index()];
}

// 8) Combine the high paddings of the two tensor::PadOps.
SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
for (auto &en : enumerate(newHighPad)) {
  if (innerDims.test(en.index()))
    newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
  if (outerDims.test(en.index()))
    newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
}

// 9) Create a new tensor::ExtractSliceOp / tensor::PadOp pair.
auto newSliceOp = rewriter.create<ExtractSliceOp>(
    padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
    innerSliceOp.getMixedStrides());
auto newPadOp = rewriter.create<PadOp>(
    padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
    padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
    /*attrs=*/ArrayRef<NamedAttribute>{});
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
                            newPadOp.getRegion().begin());
rewriter.replaceOp(padOp, newPadOp.getResult());
Value input = padTensorOp.getSource();
if (!llvm::isa<RankedTensorType>(input.getType()))
  return failure();
auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
auto inputRank = inputDims.size();

auto oldResultType =
    dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
if (!oldResultType)
  return failure();

auto outputDims = oldResultType.getShape();

// Extract the static info from the high and low operands.
SmallVector<int64_t> constOperandsLow;
for (auto operand : padTensorOp.getLow()) {
  APSInt intOp;
  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
    constOperandsLow.push_back(ShapedType::kDynamic);
    continue;
  }
  constOperandsLow.push_back(intOp.getExtValue());
}
SmallVector<int64_t> constOperandsHigh;
for (auto operand : padTensorOp.getHigh()) {
  APSInt intOp;
  if (!matchPattern(operand, m_ConstantInt(&intOp))) {
    constOperandsHigh.push_back(ShapedType::kDynamic);
    continue;
  }
  constOperandsHigh.push_back(intOp.getExtValue());
}

// Verify the op is well-formed.
if (inputDims.size() != outputDims.size() ||
    inputDims.size() != constLow.size() ||
    inputDims.size() != constHigh.size())
  return failure();

auto lowCount = 0;
auto highCount = 0;
for (size_t i = 0; i < inputRank; i++) {
  if (constLow[i] == ShapedType::kDynamic)
    constLow[i] = constOperandsLow[lowCount++];
  if (constHigh[i] == ShapedType::kDynamic)
    constHigh[i] = constOperandsHigh[highCount++];
}

// Calculate the output sizes with the static information.
SmallVector<int64_t> newOutDims;
for (size_t i = 0; i < inputRank; i++) {
  if (outputDims[i] == ShapedType::kDynamic) {
    newOutDims.push_back(
        (staticLow[i] == ShapedType::kDynamic ||
                 staticHigh[i] == ShapedType::kDynamic ||
                 inputDims[i] == ShapedType::kDynamic
             ? ShapedType::kDynamic
             : inputDims[i] + staticLow[i] + staticHigh[i]));
  } else {
    newOutDims.push_back(outputDims[i]);
  }
}

if (llvm::all_of(newOutDims,
                 [&](int64_t x) { return x == ShapedType::kDynamic; }))
  return failure();

// Rewrite the op using the new static type.
auto newResultType = RankedTensorType::get(
    newOutDims, padTensorOp.getType().getElementType());
auto newOp = rewriter.create<PadOp>(
    padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
    padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
    /*attrs=*/ArrayRef<NamedAttribute>{});

IRMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
            FoldOrthogonalPaddings, FoldStaticPadding>(context);

Value PadOp::getConstantPaddingValue() {
  auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());

  Value padValue = yieldOp.getValue();

if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
OpResult ParallelInsertSliceOp::getTiedOpResult() {
  ParallelCombiningOpInterface parallelCombiningParent =
      getParallelCombiningParent();
  for (const auto &it :
       llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
    Operation &nextOp = it.value();
    if (&nextOp == getOperation())
      return parallelCombiningParent.getParentResult(it.index());
  }
  llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}

build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,

build(b, result, source, dest, offsets, sizes, strides, attrs);

build(b, result, source, dest, offsetValues, sizeValues, strideValues);
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
  return this->emitError("expected ParallelCombiningOpInterface parent, got:")
         << *(getOperation()->getParentOp());

RankedTensorType expectedType;
SliceVerificationResult result =
    verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
                        getStaticSizes(), getStaticStrides(), &expectedType);

void ParallelInsertSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
              InsertSliceOpCastFolder<ParallelInsertSliceOp>,
              InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
void ScatterOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "scatter");
}

int64_t destRank = getDestType().getRank();
ArrayRef<int64_t> scatterDims = getScatterDims();
if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
                                     "scatter", "dest")))
  return failure();

if (!getUnique())
  return emitOpError("requires 'unique' attribute to be set");

RankedTensorType expectedSourceType = GatherOp::inferResultType(
    getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
    getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
if (getSourceType() != expectedSourceType &&
    getSourceType() != expectedRankReducedSourceType) {
  return emitOpError("source type "
                     "mismatch: expected ")
         << expectedSourceType << " or its rank-reduced variant "
         << expectedRankReducedSourceType << " (got: " << getSourceType()
         << ")";
}
void SplatOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "splat");
}

auto constOperand = adaptor.getInput();
if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
  return {};
Value insertExpand(RewriterBase &rewriter, Location loc, Value operand,
                   Type newOperandType, ArrayAttr reassociation) const {
  if (operand.getType() == newOperandType)
    return operand;
  return rewriter.create<tensor::ExpandShapeOp>(loc, newOperandType, operand,
                                                reassociation);
}

RankedTensorType sourceType = packOp.getSourceType();
RankedTensorType destType = packOp.getDestType();
if (sourceType.getRank() != 1 || packOp.getPaddingValue())
  return failure();
auto reassociation =
    getReassociationIndicesForReshape(sourceType, destType);

Value expanded = insertExpand(
    rewriter, packOp.getLoc(), packOp.getSource(), destType,
    getReassociationIndicesAttribute(rewriter, *reassociation));

patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
template <typename OpTy>
static LogicalResult reifyResultShapesImpl(
    OpTy op, OpBuilder &builder,
    ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind each tiled dimension to its tile factor.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (!ShapedType::isDynamic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}

template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
size_t dimsPosSize = dimsPos.size();
if (dimsPosSize > rank)
  return true;
DenseSet<int64_t> uniqued;
for (int64_t dim : dimsPos)
  uniqued.insert(dim);
if (dimsPosSize != uniqued.size())
  return true;
return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
});

assert(sourceShape.size() == limitShape.size() &&
       "expected source shape rank, and limit of the shape to have same rank");
return llvm::all_of(
    llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
      int64_t sourceExtent = std::get<0>(it);
      int64_t limit = std::get<1>(it);
      return ShapedType::isDynamic(sourceExtent) ||
             ShapedType::isDynamic(limit) || sourceExtent <= limit;
    });
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(
        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
  };

  // Verify tiles. Do not allow zero tile factors.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be less than or equal to the input rank for pack (or
  // output rank for unpack), and must match the number of `inner_dims_pos`.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // The packed rank must equal the unpacked rank plus the number of tiling
  // factors.
  if (unpackedRank + mixedTiles.size() != packedRank) {
    return op->emitError(
        "packed rank must equal unpacked rank + tiling factors");
  }

  // Verify that the packed shape is large enough to hold the unpacked data,
  // and that any statically known tiled dimension matches its tile size.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            std::optional<int64_t> constTileSize =
                getConstantIntValue(std::get<1>(it));
            int64_t shape = std::get<0>(it);
            if (!constTileSize) {
              // If the specified tile size is dynamic, the output shape
              // should be dynamic too.
              return ShapedType::isDynamic(shape);
            }
            if (ShapedType::isDynamic(shape)) {
              // A dynamic shape with a constant tile size is accepted; in
              // canonical form the shape would be constant, but that is not
              // required for verification.
              return true;
            }
            return shape == constTileSize.value();
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(outerDimsPerm),
        builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
        builder.getDenseI64ArrayAttr(staticTileSizes));
}
for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
  if (ShapedType::isDynamic(inputShape[pos]))
    continue;
  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
  if (!constantTile)
    continue;
  if (inputShape[pos] % (*constantTile) != 0)
    return true;
}
return false;

auto paddingValue = getPaddingValue();
if (paddingValue &&
    paddingValue.getType() != getSourceType().getElementType()) {
  return emitOpError("expected padding_value has ")
         << getSourceType().getElementType()
         << " but got: " << paddingValue.getType();
}

if (!paddingValue &&
    requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                        getMixedTiles())) {
  return emitOpError("invalid tile factor provided. Only full tiles are "
                     "supported when padding_value is not set");
}
for (auto o : ofrs) {
  // Have to do this first, as getConstantIntValue special-cases constants.
  if (llvm::dyn_cast_if_present<Value>(o))
    result.push_back(ShapedType::kDynamic);

if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
  continue;
if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
  resultShape[tiledDim.value()] = ShapedType::kDynamic;
  continue;
}
resultShape[tiledDim.value()] = ceilDiv(resultShape[tiledDim.value()],
                                        innerTileSizes[tiledDim.index()]);

if (!outerDimsPerm.empty())
  applyPermutationToVector(resultShape, outerDimsPerm);

resultShape.append(innerTileSizes.begin(), innerTileSizes.end());

resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
    builder, loc, ceilDivExpr,
    {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});

if (!outerDimsPerm.empty())
  applyPermutationToVector(resultDims, outerDimsPerm);

resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

    innerDimsPos, outerDimsPerm);

for (unsigned i = 0; i < resultDims.size(); ++i) {
  if (!ShapedType::isDynamic(resultTypeShape[i]))
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}
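// Packed-type example (hypothetical): source tensor<128x256xf32> with
// innerDimsPos = [0, 1] and innerTileSizes = [8, 16] gives outer extents
// ceil(128/8) = 16 and ceil(256/16) = 16, then the tile sizes are appended:
// tensor<16x16x8x16xf32>.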
for (auto [index, value] : llvm::enumerate(
         llvm::cast<RankedTensorType>(source.getType()).getShape())) {
  if (ShapedType::isDynamic(value))
    mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
  else
    mixedSizes.push_back(b.getIndexAttr(value));
}
for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
  int64_t dimPos = std::get<0>(it);
  OpFoldResult tileSize = std::get<1>(it);
  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
}
if (!outerDimsPerm.empty())
  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
    *this, innerPermutation, outerPermutation);
Value transposedDest =
    createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                            metadata.innerDimsPos, metadata.outerDimsPerm);
return b.create<PackOp>(loc, getSource(), transposedDest,
                        metadata.innerDimsPos, metadata.innerTiles,
                        getPaddingValue(), metadata.outerDimsPerm);

template <typename OpTy>
static bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

if (getPaddingValue())
3755 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
3757 return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}
// PackOp canonicalization (fragment): fold pack(unpack(x)) back to x when the
// types round-trip, there is no padding, and tiles and attributes agree.
  UnPackOp unPackOp = packOp.getSource().getDefiningOp<UnPackOp>();
  if (!unPackOp || unPackOp.getSourceType() != packOp.getDestType())
    return failure();
  if (packOp.getPaddingValue() ||
      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
      !haveSameTiles(packOp, unPackOp))
    return failure();
  rewriter.replaceOp(packOp, unPackOp.getSource());
  return success();
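// Effect on IR (illustrative sketch):
//   %u = tensor.unpack %p ... : tensor<4x8x8x2xf32> -> tensor<32x16xf32>
//   %r = tensor.pack %u ...   : tensor<32x16xf32> -> tensor<4x8x8x2xf32>
// All uses of %r are replaced with %p and the pair becomes dead.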
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
                    std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
                "Function meant for pack/unpack");
  // Check that no dimensions are transposed.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions have been permuted: this is not like a pad.
    return false;
  }
  // ...
  int64_t packedRank = packedTensorType.getRank();
  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  // The op is pad-like if every non-tiled (leading) packed dim is 1.
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}
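// Example (illustrative): a pack of tensor<32x16xf32> into
// tensor<1x1x32x16xf32> with innerDimsPos = {0, 1} and full-dim tiles
// {32, 16} keeps the dims in order and every leading outer dim is 1, so the
// op behaves like a tensor.pad.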
bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}
void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}
// UnPackOp::build mirrors PackOp::build, minus the padding value.
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  // ...
  build(builder, state, dest.getType(), source, dest,
        outerDimsPerm.empty()
            ? nullptr
            : builder.getDenseI64ArrayAttr(outerDimsPerm),
        /*...*/);
// UnPackOp::createDestinationTensor (fragment): the unpacked shape multiplies
// each tiled outer dim back by its tile size. (dimMul is a local affine-apply
// helper from the elided prologue.)
  auto srcType = llvm::cast<RankedTensorType>(source.getType());
  for (auto i :
       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
    if (srcType.isDynamicDim(i))
      mixedSizes.push_back(b.create<DimOp>(loc, source, i).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
  }
  if (!outerDimsPerm.empty()) {
    applyPermutationToVector<OpFoldResult>(
        mixedSizes, invertPermutationVector(outerDimsPerm));
  }

  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);

  auto elemType = srcType.getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
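// Hedged sketch: each unpacked size is outer * tile, so a
// tensor<4x8x8x2xf32> source with tiles {8, 2} produces
//   %dest = tensor.empty() : tensor<32x16xf32>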
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(loc, transposedSource, getDest(),
                            metadata.innerDimsPos, metadata.innerTiles,
                            metadata.outerDimsPerm);
}
// UnPackOp canonicalization (fragment): the mirror fold, unpack(pack(x)) -> x
// under the same no-padding and matching-attribute conditions.
  PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
  if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
    return failure();
  if (packOp.getPaddingValue() ||
      !hasSameInnerOuterAttribute(packOp, unPackOp) ||
      !haveSameTiles(packOp, unPackOp))
    return failure();
  rewriter.replaceOp(unPackOp, packOp.getSource());
  return success();
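// Effect on IR (illustrative sketch, mirroring the pack(unpack) fold above):
//   %p = tensor.pack %x ...   : tensor<32x16xf32> -> tensor<4x8x8x2xf32>
//   %u = tensor.unpack %p ... : tensor<4x8x8x2xf32> -> tensor<32x16xf32>
// All uses of %u are replaced with %x directly.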
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}
// FoldTensorCastProducerOp::matchAndRewrite: folds a producing tensor.cast
// into a consuming DestinationStyleOpInterface op when the cast makes the
// operand type more static.
    // InsertSliceOp has its own logic about folding tensor.cast ops.
    if (isa<InsertSliceOp>(op.getOperation()))
      return failure();
    // Exclude DPS ops that are also LoopLike from this pattern, as they may
    // need special handling of attached regions.
    if (isa<LoopLikeOpInterface>(op.getOperation()))
      return failure();

    // Fail if no operand comes from a foldable tensor::CastOp.
    bool hasTensorCastOperand =
        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
          if (llvm::isa<BlockArgument>(opOperand.get()))
            return false;
          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
      return failure();

    // Swap each foldable cast for its more-static operand; DPS init operands
    // also determine the new result types.
    SmallVector<Type, 4> newResultTypes;
    SmallVector<Value, 4> newOperands;
    for (OpOperand &opOperand : op->getOpOperands()) {
      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
      bool fold = tensorCastOp && canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
      if (op.isDpsInit(&opOperand) &&
          !llvm::isa<MemRefType>(newOperands.back().getType()))
        newResultTypes.push_back(newOperands.back().getType());
    }

    // Clone with the new operands/result types, then cast back any result
    // whose type changed.
    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
    SmallVector<Value, 4> replacements;
    for (auto [oldResult, newResult] :
         llvm::zip(op->getResults(), newOp->getResults())) {
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);
    return success();
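// Effect on IR (illustrative; `some_dps_op` is a placeholder for any
// DestinationStyleOpInterface op, not a real operation):
//   %c = tensor.cast %t : tensor<4x8xf32> to tensor<?x?xf32>
//   %r = some_dps_op ... outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
// becomes
//   %r2 = some_dps_op ... outs(%t : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %r  = tensor.cast %r2 : tensor<4x8xf32> to tensor<?x?xf32>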
void TensorDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<FoldTensorCastProducerOp>(getContext());
}

#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"