28 #include "llvm/ADT/DenseSet.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallBitVector.h"
31 #include "llvm/ADT/StringRef.h"
43 if (
auto op = arith::ConstantOp::materialize(builder, value, type, loc))
45 if (complex::ConstantOp::isBuildableWith(value, type))
46 return builder.
create<complex::ConstantOp>(loc, type,
47 llvm::cast<ArrayAttr>(value));
53 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
55 if (tensorType.isDynamicDim(dim))
56 return builder.
createOrFold<tensor::DimOp>(loc, value, dim);
63 auto tensorType = llvm::cast<RankedTensorType>(value.
getType());
65 for (int64_t i = 0; i < tensorType.getRank(); ++i)
72 auto tensorType = llvm::dyn_cast<TensorType>(opResult.
getType());
73 assert(tensorType &&
"expected tensor type");
77 auto destOp = opResult.
getDefiningOp<DestinationStyleOpInterface>();
79 return destOp.getTiedOpOperand(opResult)->get();
87 if (!tensorType.hasStaticShape()) {
95 for (int64_t sz : tensorType.getShape())
101 b.
create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
109 if (llvm::isa<TensorType>(opResult.getType())) {
113 result.push_back(*destination);
120 if (
auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
121 if (
auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
122 return rtp1.getShape() == rtp2.getShape() &&
123 rtp1.getElementType() == rtp2.getElementType();
133 llvm::SmallBitVector droppedDims(mixedSizes.size());
134 int64_t shapePos = 0;
136 for (
const auto &size :
enumerate(mixedSizes)) {
138 bool isStaticUnitSize =
140 llvm::cast<IntegerAttr>(size.value().get<
Attribute>()).getInt() == 1;
142 if (shapePos ==
static_cast<int64_t
>(reducedShape.size())) {
145 assert(isStaticUnitSize &&
"expected unit dim");
146 droppedDims.set(size.index());
151 if (!isStaticUnitSize) {
157 if (reducedShape[shapePos] == 1) {
163 droppedDims.set(size.index());
166 assert(shapePos ==
static_cast<int64_t
>(reducedShape.size()) &&
167 "dimension mismatch");
174 static RankedTensorType
178 type.getShape().end());
179 assert(type.getNumDynamicDims() ==
180 static_cast<int64_t
>(dynamicSizes.size()) &&
181 "incorrect number of dynamic sizes");
185 for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
186 if (type.isDynamicDim(i)) {
187 Value dynamicSize = dynamicSizes[ctr++];
189 if (cst.has_value()) {
191 if (cst.value() < 0) {
192 foldedDynamicSizes.push_back(dynamicSize);
195 staticShape[i] = *cst;
197 foldedDynamicSizes.push_back(dynamicSize);
211 if (inputs.size() != 1 || outputs.size() != 1)
213 Type a = inputs.front(), b = outputs.front();
214 auto aT = dyn_cast<TensorType>(a);
215 auto bT = dyn_cast<TensorType>(b);
219 if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
234 auto tensorBitcastOperand =
235 tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
236 if (!tensorBitcastOperand)
239 auto resultType = cast<TensorType>(tensorBitcast.getType());
240 rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
241 tensorBitcastOperand.getOperand());
250 results.
add<ChainedTensorBitcast>(context);
258 setNameFn(getResult(),
"cast");
264 auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
265 auto targetType = llvm::dyn_cast<RankedTensorType>(target);
268 if (!sourceType || !targetType)
272 if (sourceType.getElementType() != targetType.getElementType())
276 if (sourceType.getRank() != targetType.getRank())
280 if (sourceType.getEncoding() != targetType.getEncoding())
284 for (
auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
285 if (!ShapedType::isDynamic(std::get<0>(t)) &&
286 ShapedType::isDynamic(std::get<1>(t)))
322 castOp.getSource().getType());
357 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
359 operand.set(castOp.getOperand());
367 if (inputs.size() != 1 || outputs.size() != 1)
369 Type a = inputs.front(), b = outputs.front();
370 auto aT = llvm::dyn_cast<TensorType>(a);
371 auto bT = llvm::dyn_cast<TensorType>(b);
375 if (aT.getElementType() != bT.getElementType())
391 int64_t rank = one.getRank();
392 if (rank != two.getRank())
397 for (int64_t i = 0; i < rank; ++i) {
398 if (one.isDynamicDim(i)) {
399 join.push_back(two.getDimSize(i));
402 if (two.isDynamicDim(i)) {
403 join.push_back(one.getDimSize(i));
406 if (one.getDimSize(i) != two.getDimSize(i))
408 join.push_back(one.getDimSize(i));
422 auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
424 if (!tensorCastOperand)
428 llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
429 auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
430 auto resultType = llvm::cast<TensorType>(tensorCast.getType());
444 auto newJoin =
joinShapes(sourceType, resultType);
445 if (firstJoin != newJoin)
448 rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
449 tensorCastOperand.getOperand());
471 auto extractOperand =
472 tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
475 auto rankedResultType =
476 llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
477 if (!rankedResultType)
481 rankedResultType.getShape() ==
482 llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
488 extractOperand.getStaticSizes(), extractOperand.getType().getShape());
490 for (
size_t i = 0, e = sizes.size(); i < e; i++) {
491 if (dimMask && dimMask->count(i))
493 int64_t dim = rankedResultType.getShape()[dimIndex++];
494 if (ShapedType::isDynamic(dim))
496 sizes[i] = rewriter.getIndexAttr(dim);
499 rewriter.replaceOpWithNewOp<ExtractSliceOp>(
500 tensorCast, rankedResultType, extractOperand.getSource(),
501 extractOperand.getMixedOffsets(), sizes,
502 extractOperand.getMixedStrides());
511 results.
add<ChainedTensorCast, TensorCastExtractSlice>(context);
518 RankedTensorType ConcatOp::inferResultType(int64_t dim,
TypeRange inputTypes) {
519 assert(!inputTypes.empty() &&
"cannot concatenate 0 tensors");
521 llvm::to_vector<4>(llvm::map_range(inputTypes, [](
Type type) {
522 return llvm::cast<RankedTensorType>(type);
524 int64_t concatRank = tensorTypes[0].getRank();
527 assert(dim >= 0 && dim < concatRank &&
"Invalid concatenation dim");
530 for (int64_t i = 0, e = concatRank; i < e; ++i) {
534 for (
auto tensorType : tensorTypes)
539 for (
auto tensorType : tensorTypes)
542 sizes[dim] = concatSize.asInteger();
549 inferResultType(dim, inputs.
getTypes());
550 assert(
succeeded(resultType) &&
"failed to infer concatenation result type");
551 build(builder, result, *resultType, dim, inputs);
555 if (getInputs().size() < 1)
556 return emitOpError(
"requires at least one input");
559 for (
auto input : getInputs())
560 inputTypes.push_back(cast<RankedTensorType>(input.getType()));
562 RankedTensorType resultType = getResultType();
563 int64_t resultRank = getRank();
564 if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
565 return type.getRank() != resultRank;
567 return emitOpError(
"rank of concatenated inputs must match result rank");
569 Type resultElementType = resultType.getElementType();
570 if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
571 return type.getElementType() != resultElementType;
573 return emitOpError(
"inputs and result element type must match");
575 int64_t dim = getDim();
576 if (dim >= resultRank)
577 return emitOpError(
"concatenation dim must be less than the tensor rank");
580 for (int64_t i = 0, e = resultRank; i < e; ++i) {
584 for (
auto tensorType : inputTypes) {
588 return emitOpError(
"static concatenation size mismatch along ")
589 <<
"non-concatenated dimension " << i;
595 for (
auto tensorType : inputTypes)
598 sizes[dim] = concatSize.asInteger();
599 auto inferredResultType =
602 for (
auto [inferredSize, actualSize] :
603 llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
604 bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
605 ShapedType::isDynamic(actualSize);
606 if (!hasDynamic && inferredSize != actualSize)
607 return emitOpError(
"result type ")
608 << resultType <<
"does not match inferred shape "
609 << inferredResultType <<
" static sizes";
619 int64_t dim = getDim();
620 RankedTensorType inferredResultType = inferResultType(dim, inputs.
getTypes());
622 Value init = inputs[0];
623 int64_t rank = getType().getRank();
630 for (int64_t i = 0; i < rank; ++i) {
633 if (!getType().isDynamicDim(i)) {
634 reifiedReturnShapes[0][i] = builder.
getIndexAttr(getType().getDimSize(i));
635 }
else if (!inferredResultType.isDynamicDim(i)) {
638 builder.
getIndexAttr(inferredResultType.getDimSize(i)));
640 reifiedReturnShapes[0][i] =
641 builder.
create<tensor::DimOp>(init.
getLoc(), init, i).getResult();
645 if (getType().isDynamicDim(dim)) {
653 builder.
createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
661 reifiedReturnShapes[0][dim] =
667 void ConcatOp::getAsmResultNames(
669 setNameFn(getResult(),
"concat");
674 if (inputs.size() == 1 && inputs[0].
getType() == getResultType())
686 if (concatOp.getInputs().size() != 1)
689 concatOp.getInputs()[0]);
697 results.
add<SingleInputConcatOp>(context);
705 setNameFn(getResult(),
"dim");
711 Value indexValue = builder.
create<arith::ConstantIndexOp>(loc, index);
712 build(builder, result, source, indexValue);
715 std::optional<int64_t> DimOp::getConstantIndex() {
724 auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
725 if (!rankedSourceType)
736 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
741 auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
747 int64_t indexVal = index.getInt();
748 if (indexVal < 0 || indexVal >= tensorType.getRank())
752 if (!tensorType.isDynamicDim(index.getInt())) {
754 return builder.
getIndexAttr(tensorType.getShape()[index.getInt()]);
757 Operation *definingOp = getSource().getDefiningOp();
760 if (
auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
762 llvm::cast<RankedTensorType>(fromElements.getResult().getType());
765 assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
768 auto dynExtents = fromElements.getDynamicExtents().begin();
769 for (
auto dim : resultType.getShape().take_front(index.getInt()))
770 if (ShapedType::isDynamic(dim))
773 return Value{*dynExtents};
777 unsigned unsignedIndex = index.getValue().getZExtValue();
779 if (
auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
782 if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
783 sliceOp.isDynamicSize(unsignedIndex)) {
784 return {sliceOp.getDynamicSize(unsignedIndex)};
802 auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
805 Value newSource = castOp.getOperand();
818 auto source = dimOp.getSource();
819 auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
823 auto resultIndex = source.cast<
OpResult>().getResultNumber();
824 auto *initOperand = destOp.getDpsInitOperand(resultIndex);
827 dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
839 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
849 rewriter.
create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
850 if (extract.
getType() != dim.getType())
852 rewriter.
create<arith::IndexCastOp>(loc, dim.getType(), extract);
861 results.
add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
871 assert(all_of(staticShape,
872 [](int64_t sz) {
return !ShapedType::isDynamic(sz); }) &&
873 "expected only static sizes");
874 build(builder, result, staticShape, elementType,
ValueRange{}, encoding);
881 build(builder, result, tensorType, dynamicSizes);
890 build(builder, result, staticShape, elementType, dynamicSizes, encoding);
894 if (getType().getNumDynamicDims() !=
896 return emitOpError(
"incorrect number of dynamic sizes, has ")
898 << getType().getNumDynamicDims();
907 for (int64_t i = 0; i < getType().getRank(); ++i) {
908 if (getType().isDynamicDim(i)) {
911 reifiedReturnShapes[0][i] = builder.
getIndexAttr(getType().getDimSize(i));
917 Value EmptyOp::getDynamicSize(
unsigned idx) {
918 assert(getType().isDynamicDim(idx) &&
"expected dynamic dim");
920 for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
921 if (getType().isDynamicDim(i))
930 for (int64_t i = 0; i < getType().getRank(); ++i) {
931 if (getType().isDynamicDim(i)) {
934 result.push_back(b.getIndexAttr(getType().
getShape()[i]));
959 op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
962 if (foldedTensorType == op.getType())
965 auto newOp = rewriter.
create<EmptyOp>(op.
getLoc(), foldedTensorType,
977 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
978 auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
979 if (!emptyTensorOp || !maybeConstantIndex)
981 if (!emptyTensorOp.getType().isDynamicDim(*maybeConstantIndex))
984 emptyTensorOp.getDynamicSize(*maybeConstantIndex));
1011 auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
1016 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
1020 newMixedSizes.reserve(currMixedSizes.size());
1021 assert(resultShape.size() == currMixedSizes.size() &&
1022 "mismatch in result shape and sizes of empty op");
1023 for (
auto it : llvm::zip(resultShape, currMixedSizes)) {
1024 int64_t newDim = std::get<0>(it);
1028 if (
auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
1029 if (ShapedType::isDynamic(newDim) ||
1030 newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
1035 producer,
"mismatch in static value of shape of empty tensor "
1036 "result and cast result");
1038 newMixedSizes.push_back(attr);
1044 if (!ShapedType::isDynamic(newDim)) {
1045 newMixedSizes.push_back(rewriter.
getIndexAttr(newDim));
1051 newMixedSizes.push_back(currDim);
1056 resultType.getElementType());
1065 results.
add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
1066 ReplaceEmptyTensorStaticShapeDims>(context);
1075 std::optional<Attribute> cst = std::nullopt) {
1076 if (source && source.
isSplat() && result.hasStaticShape() &&
1097 struct ExtractFromTensorCast :
public OpRewritePattern<tensor::ExtractOp> {
1102 auto tensorCast = extract.getTensor().
getDefiningOp<tensor::CastOp>();
1105 if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
1108 extract, tensorCast.getSource(), extract.getIndices());
1115 void ExtractOp::getAsmResultNames(
1117 setNameFn(getResult(),
"extracted");
1122 auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
1123 if (tensorType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1124 return emitOpError(
"incorrect number of indices for extract_element");
1131 if (
Attribute tensor = adaptor.getTensor())
1132 if (
auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
1133 return splatTensor.getSplatValue<
Attribute>();
1137 for (
Attribute indice : adaptor.getIndices()) {
1138 if (!indice || !llvm::isa<IntegerAttr>(indice))
1140 indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
1144 if (
auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
1145 auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
1146 auto rank = tensorType.getRank();
1147 assert(
static_cast<int64_t
>(indices.size()) == tensorType.getRank() &&
1151 for (
int i = rank - 1; i >= 0; --i) {
1152 flatIndex += indices[i] * stride;
1153 stride *= tensorType.getDimSize(i);
1157 if (
static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
1160 return fromElementsOp.getElements()[flatIndex];
1164 if (
Attribute tensor = adaptor.getTensor()) {
1165 auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
1166 if (elementsAttr && elementsAttr.isValidIndex(indices))
1167 return elementsAttr.getValues<
Attribute>()[indices];
1175 results.
add<ExtractFromTensorCast>(context);
1182 void FromElementsOp::getAsmResultNames(
1184 setNameFn(getResult(),
"from_elements");
1189 assert(!elements.empty() &&
"expected at least one element");
1191 {
static_cast<int64_t
>(elements.size())}, elements.front().
getType());
1192 build(builder, result, resultType, elements);
1195 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
1196 if (!llvm::is_contained(adaptor.getElements(),
nullptr))
1219 struct ExtractElementFromIndexCast
1226 auto indexCast = extract.getTensor().
getDefiningOp<arith::IndexCastOp>();
1232 auto newExtract = rewriter.
create<tensor::ExtractOp>(
1233 loc, elementTy, indexCast.getIn(), extract.getIndices());
1246 results.
add<ExtractElementFromIndexCast>(context);
1253 void GatherOp::getAsmResultNames(
1255 setNameFn(getResult(),
"gather");
1270 RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
1271 RankedTensorType indicesType,
1275 resultShape.reserve(resultShape.size() + sourceType.getRank());
1276 for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
1277 if (std::binary_search(gatherDims.begin(), gatherDims.end(), idx)) {
1279 resultShape.push_back(1);
1282 resultShape.push_back(sourceType.getDimSize(idx));
1289 StringRef gatherOrScatter, StringRef sourceOrDest) {
1291 return op->
emitOpError(gatherOrScatter) <<
"_dims must be non-empty";
1293 int64_t numGatherDims = dims.size();
1294 if (numGatherDims > rank)
1296 <<
"_dims overflow " << sourceOrDest <<
" rank";
1297 for (int64_t val : dims) {
1300 <<
"_dims value must be non-negative";
1303 <<
"_dims value must be smaller than " << sourceOrDest <<
" rank";
1305 for (int64_t i = 1; i < numGatherDims; ++i) {
1306 if (dims[i - 1] >= dims[i])
1308 <<
"_dims values must be strictly increasing";
1314 int64_t sourceRank = getSourceType().getRank();
1317 "gather",
"source")))
1320 RankedTensorType expectedResultType = GatherOp::inferResultType(
1321 getSourceType(), getIndicesType(), gatherDims,
false);
1322 RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
1323 getSourceType(), getIndicesType(), gatherDims,
true);
1324 if (getResultType() != expectedResultType &&
1325 getResultType() != expectedRankReducedResultType) {
1326 return emitOpError(
"result type "
1329 << expectedResultType <<
" or its rank-reduced variant "
1330 << expectedRankReducedResultType <<
" (got: " << getResultType()
1339 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1340 getResult().getType()))
1341 return reshapedSource;
1349 void InsertOp::getAsmResultNames(
1351 setNameFn(getResult(),
"inserted");
1356 auto destType = llvm::cast<RankedTensorType>(getDest().getType());
1357 if (destType.getRank() !=
static_cast<int64_t
>(
getIndices().size()))
1358 return emitOpError(
"incorrect number of indices");
1366 if (
auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
1367 if (scalar == splatDest.getSplatValue<
Attribute>())
1376 void GenerateOp::getAsmResultNames(
1378 setNameFn(getResult(),
"generated");
1385 for (
auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
1386 if (getType().isDynamicDim(dim)) {
1387 reifiedReturnShapes[0][dim] = getOperand(idx++);
1389 reifiedReturnShapes[0][dim] =
1399 RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
1400 if (getNumOperands() != resultType.getNumDynamicDims())
1401 return emitError(
"must have as many index operands as dynamic extents "
1402 "in the result type");
1407 RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
1409 if (!llvm::all_of(getBody().getArgumentTypes(),
1411 return emitError(
"all body arguments must be index");
1412 if (getBody().getNumArguments() != resultTy.getRank())
1413 return emitError(
"must have one body argument per input dimension");
1416 auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
1418 if (yieldOp.getValue().getType() != resultTy.getElementType())
1420 "body must be terminated with a `yield` operation of the tensor "
1426 void GenerateOp::build(
1430 build(b, result, resultTy, dynamicExtents);
1435 auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
1439 b.
createBlock(bodyRegion, bodyRegion->
end(), argumentTypes, argumentLocs);
1456 generateOp.getType(), generateOp.getDynamicExtents(),
1457 foldedDynamicSizes);
1460 if (foldedTensorType == generateOp.getType())
1463 auto loc = generateOp.getLoc();
1465 rewriter.
create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
1467 newOp.getBody().begin());
1469 generateOp.getType(), newOp);
1485 struct ExtractFromTensorGenerate :
public OpRewritePattern<tensor::ExtractOp> {
1490 auto tensorFromElements = extract.getTensor().
getDefiningOp<GenerateOp>();
1495 Block *body = &tensorFromElements.getBody().
front();
1498 rewriter.
clone(op, mapping);
1512 results.
add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
1519 void RankOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
1520 setNameFn(getResult(),
"rank");
1525 auto type = getOperand().getType();
1526 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1527 if (shapedType && shapedType.hasRank())
1529 return IntegerAttr();
1536 void ReshapeOp::getAsmResultNames(
1538 setNameFn(getResult(),
"reshape");
1542 int64_t numElements = 1;
1543 for (
auto dim : type.getShape())
1549 TensorType operandType = llvm::cast<TensorType>(getSource().getType());
1550 TensorType resultType = llvm::cast<TensorType>(getResult().getType());
1553 return emitOpError(
"element types of source and destination tensor "
1554 "types should be the same");
1557 llvm::cast<RankedTensorType>(
getShape().getType()).getDimSize(0);
1558 auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
1559 auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
1561 if (resultRankedType) {
1562 if (operandRankedType && resultRankedType.hasStaticShape() &&
1563 operandRankedType.hasStaticShape()) {
1565 return emitOpError(
"source and destination tensor should have the "
1566 "same number of elements");
1568 if (ShapedType::isDynamic(shapeSize))
1569 return emitOpError(
"cannot use shape operand with dynamic length to "
1570 "reshape to statically-ranked tensor type");
1571 if (shapeSize != resultRankedType.getRank())
1573 "length of shape operand differs from the result's tensor rank");
1580 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
1581 getResult().getType()))
1582 return reshapedSource;
1590 void CollapseShapeOp::getAsmResultNames(
1592 setNameFn(getResult(),
"collapsed");
1595 void ExpandShapeOp::getAsmResultNames(
1597 setNameFn(getResult(),
"expanded");
1600 int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
1601 assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
1602 "invalid resultDim");
1604 if (llvm::is_contained(it.value(), resultDim))
1606 llvm_unreachable(
"could not find reassociation group");
1614 getReassociationIndices());
1622 getReassociationIndices());
1625 RankedTensorType CollapseShapeOp::inferCollapsedType(
1627 return inferCollapsedType(
1629 type.getContext(), reassociation)));
1635 CollapseShapeOp::inferCollapsedType(RankedTensorType type,
1637 auto shape = type.getShape();
1639 newShape.reserve(reassociation.size());
1644 unsigned currentDim = 0;
1646 unsigned dim = m.getNumResults();
1647 auto band = shape.slice(currentDim, dim);
1649 if (llvm::is_contained(band, ShapedType::kDynamic))
1650 size = ShapedType::kDynamic;
1652 for (
unsigned d = 0; d < dim; ++d)
1653 size *= shape[currentDim + d];
1654 newShape.push_back(size);
1664 auto resultType = inferCollapsedType(
1665 llvm::cast<RankedTensorType>(src.
getType()),
1668 build(b, result, resultType, src, attrs);
1673 template <
typename TensorReshapeOp,
bool isExpansion = std::is_same<
1674 TensorReshapeOp, ExpandShapeOp>::value>
1676 RankedTensorType expandedType,
1677 RankedTensorType collapsedType) {
1682 auto maps = op.getReassociationMaps();
1683 RankedTensorType expectedType =
1684 CollapseShapeOp::inferCollapsedType(expandedType, maps);
1686 return op.
emitOpError(
"expected collapsed type to be ")
1687 << expectedType <<
", but got " << collapsedType;
1702 template <
typename TensorReshapeOp>
1713 reshapeOp.getResultType(), attr.
getRawData());
1720 template <
typename TensorReshapeOp>
1727 auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
1728 if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
1732 reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
1739 template <
typename TensorReshapeOp>
1745 reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
1749 auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
1751 if (!shapedTy.hasStaticShape())
1755 fromElements.getElements());
1764 LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
1766 auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
1770 RankedTensorType srcType =
1771 llvm::cast<RankedTensorType>(castOp.getSource().getType());
1772 RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
1773 srcType, collapseShapeOp.getReassociationMaps());
1775 if (newResultType == collapseShapeOp.getResultType()) {
1777 collapseShapeOp.getSrcMutable().assign(castOp.getSource());
1780 auto newOp = rewriter.
create<CollapseShapeOp>(
1781 collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
1782 collapseShapeOp.getReassociation());
1784 collapseShapeOp, collapseShapeOp.getResultType(), newOp);
1795 auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
1800 std::optional<int64_t> dim = dimOp.getConstantIndex();
1801 if (!dim.has_value())
1805 RankedTensorType resultType = expandShapeOp.getResultType();
1806 if (!resultType.isDynamicDim(*dim))
1810 int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
1816 for (int64_t d : grp) {
1818 assert(!resultType.isDynamicDim(d) &&
"expected static dim");
1819 product *= resultType.getDimSize(d);
1825 rewriter.
create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
1829 dimOp, expr.floorDiv(
product), srcDimSz);
1839 auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
1840 if (!collapseShapeOp)
1844 std::optional<int64_t> dim = dimOp.getConstantIndex();
1845 if (!dim.has_value())
1849 RankedTensorType resultType = collapseShapeOp.getResultType();
1850 if (!resultType.isDynamicDim(*dim))
1855 collapseShapeOp.getReassociationIndices()[*dim];
1862 srcDimSizes.push_back(rewriter.
create<DimOp>(
1863 dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
1878 FoldReshapeWithConstant<ExpandShapeOp>,
1879 FoldReshapeWithSplat<ExpandShapeOp>,
1880 FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
1881 FoldDimOfCollapseShape>(context);
1889 FoldReshapeWithConstant<CollapseShapeOp>,
1890 FoldReshapeWithSplat<CollapseShapeOp>,
1891 FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
1895 OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
1896 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*
this,
1897 adaptor.getOperands());
1900 OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
1901 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*
this,
1902 adaptor.getOperands());
1909 void ExtractSliceOp::getAsmResultNames(
1911 setNameFn(getResult(),
"extracted_slice");
1917 RankedTensorType ExtractSliceOp::inferResultType(
1923 assert(
static_cast<int64_t
>(staticSizes.size()) ==
1924 sourceTensorType.getRank() &&
1925 "unexpected staticSizes not equal to rank of source");
1929 RankedTensorType ExtractSliceOp::inferResultType(
1937 return ExtractSliceOp::inferResultType(sourceTensorType, staticOffsets,
1938 staticSizes, staticStrides);
1949 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1950 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1954 auto inferredType = llvm::cast<RankedTensorType>(
1955 inferResultType(sourceRankedTensorType, offsets, sizes, strides));
1956 int rankDiff = inferredType.getRank() - desiredResultRank;
1958 auto shape = inferredType.getShape();
1959 llvm::SmallBitVector dimsToProject =
1963 for (
unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1964 if (!dimsToProject.test(pos))
1965 projectedShape.push_back(shape[pos]);
1969 return inferredType;
1972 RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
1973 unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
1981 return ExtractSliceOp::inferCanonicalRankReducedResultType(
1982 desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1989 RankedTensorType resultType,
Value source,
1999 auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.
getType());
2002 resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
2003 sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
2005 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2019 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2028 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2034 RankedTensorType resultType,
Value source,
2043 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
2050 build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
2055 RankedTensorType expectedType) {
2060 return op->
emitError(
"expected rank to be smaller or equal to ")
2061 <<
"the other rank. ";
2063 return op->
emitError(
"expected type to be ")
2064 << expectedType <<
" or a rank-reduced version. (size mismatch) ";
2066 return op->
emitError(
"expected element type to be ")
2067 << expectedType.getElementType();
2069 llvm_unreachable(
"unexpected extract_slice op verification result");
2076 RankedTensorType expectedType = ExtractSliceOp::inferResultType(
2077 getSourceType(), getMixedOffsets(),
getMixedSizes(), getMixedStrides());
2089 auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.
getType());
2090 assert(sourceTensorType &&
"not a ranked tensor type");
2091 auto sourceShape = sourceTensorType.getShape();
2092 if (sourceShape.equals(desiredShape))
2094 auto maybeRankReductionMask =
2096 if (!maybeRankReductionMask)
2105 reifiedReturnShapes.resize(1);
2106 reifiedReturnShapes[0].reserve(getType().getRank());
2109 for (
const auto &size :
enumerate(mixedSizes)) {
2110 if (droppedDims.test(size.index()))
2112 reifiedReturnShapes[0].push_back(size.value());
2133 class ExtractSliceOpCastFolder final :
public OpRewritePattern<ExtractSliceOp> {
2140 if (llvm::any_of(sliceOp.getOperands(), [](
Value operand) {
2141 return matchPattern(operand, matchConstantIndex());
2145 auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
2154 Value newResult = rewriter.
create<ExtractSliceOp>(
2155 loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
2156 sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
2157 sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
2158 if (newResult.
getType() != sliceOp.getType())
2159 newResult = rewriter.
create<CastOp>(loc, sliceOp.getType(), newResult);
2168 template <
typename IterTy,
typename ElemTy>
2173 assert(offsets.size() == sizes.size());
2174 assert(offsets.size() == strides.size());
2175 if (offsets.empty())
2178 int64_t offset = offsets.front();
2179 int64_t size = sizes.front();
2180 int64_t stride = strides.front();
2181 if (offsets.size() == 1) {
2182 for (int64_t i = 0; i < size; ++i, offset += stride)
2183 outValues->push_back(*(values + offset));
2188 for (int64_t i = 0; i < size; ++i, offset += stride) {
2189 auto begin = values + offset * counts.front();
2190 sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
2191 offsets.drop_front(), sizes.drop_front(),
2192 strides.drop_front(), outValues);
2199 class ConstantOpExtractSliceFolder final
2204 ConstantOpExtractSliceFolder(
MLIRContext *context,
2207 controlFn(std::move(controlFn)) {}
2220 auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
2222 if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
2229 int64_t count = sourceType.getNumElements();
2234 auto offsets = op.getStaticOffsets();
2235 if (llvm::is_contained(offsets, ShapedType::kDynamic))
2237 auto sizes = op.getStaticSizes();
2238 if (llvm::is_contained(sizes, ShapedType::kDynamic))
2240 auto strides = op.getStaticStrides();
2241 if (llvm::is_contained(strides, ShapedType::kDynamic))
2247 counts.reserve(shape.size());
2248 for (int64_t v : shape) {
2250 counts.push_back(count);
2256 if (
auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
2258 outValues.reserve(sourceType.getNumElements());
2259 sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
2260 elems.begin(), counts, offsets, sizes, strides, &outValues);
2262 }
else if (
auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
2264 outValues.reserve(sourceType.getNumElements());
2265 sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
2266 elems.begin(), counts, offsets, sizes, strides, &outValues);
2289 patterns.
add<ConstantOpExtractSliceFolder>(patterns.
getContext(), controlFn);
2298 return ExtractSliceOp::inferCanonicalRankReducedResultType(
2299 op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
2307 ExtractSliceOp newOp) {
2308 Value replacement = newOp.getResult();
2309 if (replacement.
getType() != op.getType())
2310 replacement = rewriter.
create<tensor::CastOp>(op.
getLoc(), op.getType(),
2321 ExtractSliceOpCastFolder>(context);
2327 ShapedType shapedType) {
2334 auto shape = shapedType.getShape();
2335 for (
auto it : llvm::zip(op.getMixedSizes(), shape))
2349 auto insertOp = extractOp.getSource().
getDefiningOp<InsertSliceOp>();
2352 if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
2353 insertOp.isSameAs(extractOp, isSame))
2354 return insertOp.getSource();
2359 OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
2361 llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
2362 getResult().getType()))
2363 return reshapedSource;
2364 if (getSourceType() == getType() &&
2366 return this->getSource();
2375 auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.
getType());
2376 unsigned rank = rankedTensorType.getRank();
2380 return b.
createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
2381 offsets, sizes, strides);
2388 void InsertSliceOp::getAsmResultNames(
2390 setNameFn(getResult(),
"inserted_slice");
2404 build(b, result, dest.
getType(), source, dest, dynamicOffsets, dynamicSizes,
2417 build(b, result, source, dest, offsets, sizes, strides, attrs);
2430 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
2436 RankedTensorType srcType, RankedTensorType dstType,
2441 RankedTensorType expected = ExtractSliceOp::inferResultType(
2442 dstType, staticOffsets, staticSizes, staticStrides);
2444 *expectedType = expected;
2450 RankedTensorType expectedType;
2453 getStaticSizes(), getStaticStrides(), &expectedType);
2475 auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
2478 if (!prevInsertOp ||
2479 prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
2480 !prevInsertOp.isSameAs(insertOp, isSame))
2483 insertOp.getDestMutable().assign(prevInsertOp.getDest());
2495 auto extractOp = insertOp.getSource().
getDefiningOp<ExtractSliceOp>();
2498 if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
2499 !extractOp.isSameAs(insertOp, isSame))
2502 return extractOp.getSource();
2506 if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
2507 getSourceType() == getType() &&
2509 return this->getSource();
2528 template <
typename InsertOpTy>
2529 class InsertSliceOpConstantArgumentFolder final
2547 auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
2548 insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
2549 mixedOffsets, mixedSizes, mixedStrides);
2550 Value toInsert = insertSliceOp.getSource();
2551 if (sourceType != insertSliceOp.getSourceType()) {
2556 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2558 toInsert = rewriter.
create<tensor::CastOp>(insertSliceOp.getLoc(),
2559 sourceType, toInsert);
2562 insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
2563 mixedSizes, mixedStrides);
2588 template <
typename InsertOpTy>
2589 struct InsertSliceOpCastFolder final :
public OpRewritePattern<InsertOpTy> {
2594 if (llvm::any_of(insertSliceOp.getOperands(), [](
Value operand) {
2595 return matchPattern(operand, matchConstantIndex());
2599 auto getSourceOfCastOp = [](
Value v) -> std::optional<Value> {
2600 auto castOp = v.getDefiningOp<tensor::CastOp>();
2602 return std::nullopt;
2603 return castOp.getSource();
2605 std::optional<Value> sourceCastSource =
2606 getSourceOfCastOp(insertSliceOp.getSource());
2607 std::optional<Value> destCastSource =
2608 getSourceOfCastOp(insertSliceOp.getDest());
2609 if (!sourceCastSource && !destCastSource)
2613 (sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
2614 auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
2615 auto srcType = llvm::dyn_cast<RankedTensorType>(src.
getType());
2616 auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
2617 if (!srcType || !dstType)
2620 insertSliceOp.getStaticSizes(),
2621 insertSliceOp.getStaticStrides()) !=
2626 insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
2627 insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
2630 bool isParallelInsert =
2631 std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
2632 if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
2633 replacement = rewriter.
create<tensor::CastOp>(insertSliceOp.getLoc(),
2634 insertSliceOp.getDestType(),
2663 template <
typename InsertOpTy>
2664 struct InsertSliceOpSourceCastInserter final
2670 RankedTensorType srcType = insertSliceOp.getSourceType();
2671 if (srcType.getRank() != insertSliceOp.getDestType().getRank())
2674 srcType.getShape().end());
2675 for (int64_t i = 0; i < srcType.getRank(); ++i) {
2676 if (std::optional<int64_t> constInt =
2681 newSrcShape[i] = *constInt;
2688 newSrcShape, srcType.getElementType(), srcType.getEncoding());
2689 if (srcType == newSrcType ||
2691 !tensor::CastOp::areCastCompatible(srcType, newSrcType))
2703 if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
2706 insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
2708 insertSliceOp, cast, insertSliceOp.getDest(),
2709 insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
2710 insertSliceOp.getMixedStrides());
2722 results.
add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
2723 InsertSliceOpCastFolder<InsertSliceOp>,
2724 InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
2731 auto rankedTensorType = llvm::cast<RankedTensorType>(dest.
getType());
2732 unsigned rank = rankedTensorType.getRank();
2736 return b.
createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
2745 setNameFn(getResult(),
"padded");
2751 Type typeToInfer,
Type typeToInferFrom) {}
2755 std::optional<OpAsmParser::UnresolvedOperand> optOperand,
2756 Type &typeToInfer,
Type typeToInferFrom) {
2758 typeToInfer = typeToInferFrom;
2763 auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
2764 auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
2766 PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
2767 if (!expectedType) {
2768 return emitError(
"failed to infer expectedType from sourceType ")
2769 << sourceType <<
", specified resultType is " << resultType;
2771 if (resultType.getRank() != expectedType.getRank()) {
2773 << resultType <<
" does not match the inferred type "
2776 for (
int i = 0, e = sourceType.getRank(); i < e; ++i) {
2777 if (resultType.getDimSize(i) == expectedType.getDimSize(i))
2779 if (expectedType.isDynamicDim(i))
2782 << resultType <<
" does not match the inferred type "
2790 auto ®ion = getRegion();
2791 unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
2794 return emitError(
"expected the block to have ") << rank <<
" arguments";
2798 if (!en.value().isIndex())
2799 return emitOpError(
"expected block argument ")
2800 << (en.index() + 1) <<
" to be an index";
2805 if (yieldOp.getValue().getType() !=
2807 return emitOpError(
"expected yield type to match shape element type");
2812 RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
2816 unsigned rank = sourceType.getRank();
2817 if (staticLow.size() != rank)
2818 return RankedTensorType();
2819 if (staticHigh.size() != rank)
2820 return RankedTensorType();
2821 if (!resultShape.empty() && resultShape.size() != rank)
2822 return RankedTensorType();
2825 for (
auto i : llvm::seq<unsigned>(0, rank)) {
2826 if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
2827 staticHigh[i] == ShapedType::kDynamic) {
2828 inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
2831 int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
2832 assert((resultShape.empty() || size == resultShape[i] ||
2833 resultShape[i] == ShapedType::kDynamic) &&
2834 "mismatch between inferred shape and result shape");
2835 inferredShape.push_back(size);
2846 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
2848 resultType = inferResultType(sourceType, staticLow, staticHigh);
2849 build(b, result, resultType, source, low, high,
2858 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
2859 unsigned rank = sourceType.getRank();
2861 build(b, result, resultType, source, staticVector, staticVector, low, high,
2869 auto sourceType = llvm::cast<RankedTensorType>(source.
getType());
2879 resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
2881 assert(llvm::isa<RankedTensorType>(resultType));
2882 build(b, result, resultType, source, dynamicLow, dynamicHigh,
2892 build(b, result, resultType, source, low, high, nofold, attrs);
2896 int sourceRank = llvm::cast<RankedTensorType>(source.
getType()).getRank();
2903 b.
createBlock(region, region->
end(), blockArgTypes, blockArgLocs);
2907 llvm::SmallBitVector PadOp::getPaddedDims() {
2908 llvm::SmallBitVector paddedDims(getSourceType().getRank());
2910 for (
const auto &en :
enumerate(paddingWidths))
2912 paddedDims.set(en.index());
2914 extractPaddedDims(getMixedLowPad());
2915 extractPaddedDims(getMixedHighPad());
2927 if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
2929 if (padTensorOp.getNofold())
2932 padTensorOp, padTensorOp.getResult().getType(),
2933 padTensorOp.getSource());
2944 auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
2948 auto newResultType = PadOp::inferResultType(
2949 llvm::cast<RankedTensorType>(castOp.getSource().getType()),
2950 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2951 padTensorOp.getResultType().getShape());
2953 if (newResultType == padTensorOp.getResultType()) {
2955 padTensorOp.getSourceMutable().assign(castOp.getSource());
2958 auto newOp = rewriter.
create<PadOp>(
2959 padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
2960 padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
2961 padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
2964 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
2967 padTensorOp, padTensorOp.getResultType(), newOp);
2980 if (!padTensorOp.getResult().hasOneUse())
2983 dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
2987 tensorCastOp.getDest().getType()))
2990 auto replacementOp = rewriter.
create<PadOp>(
2991 padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
2992 padTensorOp.getSource(), padTensorOp.getStaticLow(),
2993 padTensorOp.getStaticHigh(), padTensorOp.getLow(),
2994 padTensorOp.getHigh(), padTensorOp.getNofold(),
2998 rewriter.
replaceOp(padTensorOp, replacementOp.getResult());
2999 rewriter.
replaceOp(tensorCastOp, replacementOp.getResult());
3044 auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
3047 auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
3048 if (!outerPadOp || outerPadOp.getNofold())
3050 auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
3055 int64_t rank = padOp.getSourceType().getRank();
3056 if (outerSliceOp.getSourceType().getRank() != rank) {
3058 "cannot fold rank-reducing chain");
3062 if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
3064 padOp,
"cannot fold non-unit stride ExtractSliceOps");
3068 if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
3070 "cannot fold PadOps with low padding");
3075 Value innerValue = padOp.getConstantPaddingValue();
3076 Value outerValue = outerPadOp.getConstantPaddingValue();
3077 if (!innerValue || !outerValue ||
3080 innerAttr != outerAttr) {
3082 padOp,
"cannot fold PadOps with different padding values");
3086 llvm::SmallBitVector innerDims = padOp.getPaddedDims();
3087 llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
3088 if (innerDims.anyCommon(outerDims)) {
3090 padOp,
"cannot fold PadOps with common padding dimensions");
3100 OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
3101 OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
3102 if (!innerDims.test(en.index()) &&
3104 en.value() = outerOffset;
3107 if (!outerDims.test(en.index()) &&
3109 en.value() = innerOffset;
3113 padOp,
"cannot find zero-offset and zero-padding pair");
3123 if (!outerDims.test(en.index()))
3125 OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
3126 int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
3127 assert(!ShapedType::isDynamic(sourceSize) &&
3128 "expected padded dimension to have a static size");
3131 padOp,
"cannot fold since the inner ExtractSliceOp size does not "
3132 "match the size of the outer padding");
3134 en.value() = outerSliceOp.getMixedSizes()[en.index()];
3140 if (innerDims.test(en.index()))
3141 newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
3142 if (outerDims.test(en.index()))
3143 newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
3148 auto newSliceOp = rewriter.
create<ExtractSliceOp>(
3149 padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
3150 innerSliceOp.getMixedStrides());
3151 auto newPadOp = rewriter.
create<PadOp>(
3152 padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
3153 padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
3156 newPadOp.getRegion().begin());
3157 rewriter.
replaceOp(padOp, newPadOp.getResult());
3167 Value input = padTensorOp.getSource();
3168 if (!llvm::isa<RankedTensorType>(input.
getType()))
3170 auto inputDims = llvm::cast<RankedTensorType>(input.
getType()).getShape();
3171 auto inputRank = inputDims.size();
3173 auto oldResultType =
3174 dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
3178 auto outputDims = oldResultType.getShape();
3183 for (
auto operand : padTensorOp.getLow()) {
3186 constOperandsLow.push_back(ShapedType::kDynamic);
3187 newLows.push_back(operand);
3190 constOperandsLow.push_back(intOp.getExtValue());
3194 for (
auto operand : padTensorOp.getHigh()) {
3197 constOperandsHigh.push_back(ShapedType::kDynamic);
3198 newHighs.push_back(operand);
3201 constOperandsHigh.push_back(intOp.getExtValue());
3208 if (inputDims.size() != outputDims.size() ||
3209 inputDims.size() != constLow.size() ||
3210 inputDims.size() != constHigh.size())
3215 for (
size_t i = 0; i < inputRank; i++) {
3216 if (constLow[i] == ShapedType::kDynamic)
3217 constLow[i] = constOperandsLow[lowCount++];
3218 if (constHigh[i] == ShapedType::kDynamic)
3219 constHigh[i] = constOperandsHigh[highCount++];
3227 for (
size_t i = 0; i < inputRank; i++) {
3228 if (outputDims[i] == ShapedType::kDynamic) {
3229 newOutDims.push_back(
3230 (staticLow[i] == ShapedType::kDynamic ||
3231 staticHigh[i] == ShapedType::kDynamic ||
3232 inputDims[i] == ShapedType::kDynamic
3233 ? ShapedType::kDynamic
3234 : inputDims[i] + staticLow[i] + staticHigh[i]));
3236 newOutDims.push_back(outputDims[i]);
3241 llvm::all_of(newOutDims,
3242 [&](int64_t x) {
return x == ShapedType::kDynamic; }))
3247 newOutDims, padTensorOp.getType().getElementType());
3248 auto newOp = rewriter.
create<PadOp>(
3249 padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3250 newLows, newHighs, padTensorOp.getNofold(),
3254 padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
3266 results.
add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
3267 FoldOrthogonalPaddings, FoldStaticPadding>(context);
3279 Value PadOp::getConstantPaddingValue() {
3280 auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
3283 Value padValue = yieldOp.getValue();
3295 if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
3305 OpResult ParallelInsertSliceOp::getTiedOpResult() {
3306 ParallelCombiningOpInterface parallelCombiningParent =
3307 getParallelCombiningParent();
3308 for (
const auto &it :
3311 if (&nextOp == getOperation())
3312 return parallelCombiningParent.getParentResult(it.index());
3314 llvm_unreachable(
"ParallelInsertSliceOp no tied OpResult found");
3329 build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
3343 build(b, result, source, dest, offsets, sizes, strides, attrs);
3357 build(b, result, source, dest, offsetValues, sizeValues, strideValues);
3361 if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
3362 return this->
emitError(
"expected ParallelCombiningOpInterface parent, got:")
3363 << *(getOperation()->getParentOp());
3365 RankedTensorType expectedType;
3368 getStaticSizes(), getStaticStrides(), &expectedType);
3372 void ParallelInsertSliceOp::getCanonicalizationPatterns(
3374 results.
add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
3375 InsertSliceOpCastFolder<ParallelInsertSliceOp>,
3376 InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
3387 void ScatterOp::getAsmResultNames(
3389 setNameFn(getResult(),
"scatter");
3393 int64_t destRank = getDestType().getRank();
3396 "scatter",
"dest")))
3400 return emitOpError(
"requires 'unique' attribute to be set");
3407 RankedTensorType expectedSourceType = GatherOp::inferResultType(
3408 getDestType(), getIndicesType(), scatterDims,
false);
3409 RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
3410 getDestType(), getIndicesType(), scatterDims,
true);
3411 if (getSourceType() != expectedSourceType &&
3412 getSourceType() != expectedRankReducedSourceType) {
3413 return emitOpError(
"source type "
3416 << expectedSourceType <<
" or its rank-reduced variant "
3417 << expectedRankReducedSourceType <<
" (got: " << getSourceType()
3430 build(builder, result, aggregateType, element, dynamicSizes);
3436 build(builder, result, aggregateType, element, dynamicSizes);
3444 build(builder, result, element, staticShape, dynamicSizes);
3447 void SplatOp::getAsmResultNames(
3449 setNameFn(getResult(),
"splat");
3453 if (getType().getNumDynamicDims() !=
3455 return emitOpError(
"incorrect number of dynamic sizes, has ")
3457 << getType().getNumDynamicDims();
3466 for (int64_t i = 0; i < getType().getRank(); ++i) {
3467 if (getType().isDynamicDim(i)) {
3470 reifiedReturnShapes[0][i] = builder.
getIndexAttr(getType().getDimSize(i));
3477 auto constOperand = adaptor.getInput();
3478 if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
3482 if (!getType().hasStaticShape())
3494 template <
typename OpTy>
3498 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3499 "applies to only pack or unpack operations");
3500 int64_t destRank = op.getDestRank();
3502 reifiedReturnShapes[0] =
3507 template <
typename OpTy>
3509 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3510 "applies to only pack or unpack operations");
3514 assert(tiles.size() == dimsToTile.size() &&
3515 "tiles must match indices of dimension to block");
3517 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
3518 dimAndTileMapping[dimsToTile[i]] = tiles[i];
3519 return dimAndTileMapping;
3522 template <
typename OpTy>
3524 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3525 "applies to only pack or unpack operations");
3528 unsigned dynamicValIndex = 0;
3529 for (int64_t staticTile : op.getStaticInnerTiles()) {
3530 if (!ShapedType::isDynamic(staticTile))
3533 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
3535 return mixedInnerTiles;
3538 template <
typename OpTy>
3540 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3541 "applies to only pack or unpack operations");
3554 size_t dimsPosSize = dimsPos.size();
3555 if (dimsPosSize > rank)
3558 for (int64_t dim : dimsPos)
3559 uniqued.insert(dim);
3560 if (dimsPosSize != uniqued.size())
3562 return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
3563 return dimPos < 0 || dimPos >=
static_cast<int64_t
>(rank);
3572 sourceShape.size() == limitShape.size() &&
3573 "expected source shape rank, and limit of the shape to have same rank");
3574 return llvm::all_of(
3575 llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
3576 int64_t sourceExtent = std::get<0>(it);
3577 int64_t limit = std::get<1>(it);
3578 return ShapedType::isDynamic(sourceExtent) ||
3579 ShapedType::isDynamic(limit) || sourceExtent <= limit;
3583 template <
typename OpTy>
3585 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3586 "applies to only pack or unpack operations");
3587 Operation *op = packOrUnPack.getOperation();
3591 return llvm::any_of(
3597 if (hasZeros(mixedTiles))
3598 return op->
emitError(
"invalid zero tile factor");
3601 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
3602 ? packOrUnPack.getSourceType()
3603 : packOrUnPack.getDestType();
3604 size_t unpackedRank = unpackedType.getRank();
3608 return op->
emitError(
"invalid inner_dims_pos vector");
3610 return op->
emitError(
"invalid outer_dims_perm vector");
3611 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
3612 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
3616 if (mixedTiles.size() > unpackedRank) {
3617 return op->
emitError(
"tiling factors must be less than or equal to the "
3618 "input rank for pack or output rank for unpack");
3620 if (mixedTiles.size() != innerDimsPos.size()) {
3622 "tiling factors must equal the number of dimensions to tile");
3625 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3626 ? packOrUnPack.getDestType()
3627 : packOrUnPack.getSourceType();
3628 size_t packedRank = packedType.getRank();
3630 if (unpackedRank + mixedTiles.size() != packedRank) {
3632 "packed rank must equal unpacked rank + tiling factors");
3638 RankedTensorType expectedPackedType = PackOp::inferPackedType(
3639 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
3640 if (!
areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
3641 return op->
emitError(
"the shape of output is not large enough to hold the "
3642 "packed data. Expected at least ")
3643 << expectedPackedType <<
", got " << packedType;
3646 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
3648 [](std::tuple<int64_t, OpFoldResult> it) {
3649 std::optional<int64_t> constTileSize =
3650 getConstantIntValue(std::get<1>(it));
3651 int64_t shape = std::get<0>(it);
3652 if (!constTileSize) {
3655 return ShapedType::isDynamic(shape);
3657 if (ShapedType::isDynamic(shape)) {
3664 return shape == constTileSize.value();
3666 return op->
emitError(
"mismatch in inner tile sizes specified and shaped of "
3667 "tiled dimension in the packed type");
3679 struct PackOrUnPackTransposeResult {
3686 template <
typename OpTy>
3687 static PackOrUnPackTransposeResult
3691 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3692 "applies to only pack or unpack operations");
3693 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
3694 "some permutation must be non-empty");
3695 PackOrUnPackTransposeResult metadata;
3696 metadata.innerDimsPos =
3698 metadata.innerTiles =
3700 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
3701 ? packOrUnPackOp.getSourceRank()
3702 : packOrUnPackOp.getDestRank();
3703 metadata.outerDimsPerm =
3704 packOrUnPackOp.getOuterDimsPerm().empty()
3705 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
3707 if (!innerPermutation.empty()) {
3708 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
3710 "invalid inner permutation");
3714 if (!outerPermutation.empty()) {
3715 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
3717 "invalid outer permutation");
3727 void PackOp::getAsmResultNames(
function_ref<
void(
Value, StringRef)> setNameFn) {
3728 setNameFn(getResult(),
"pack");
3734 std::optional<Value> paddingValue,
3736 assert(innerDimsPos.size() == innerTiles.size() &&
3737 "number of tile sizes specified must match the specified number of "
3738 "original dimensions to be tiled");
3742 build(builder, state, dest.
getType(), source, dest,
3743 paddingValue ? *paddingValue :
nullptr,
3744 outerDimsPerm.empty() ?
nullptr
3774 outputShape.take_front(inputShape.size()));
3775 if (!outerDimsPerm.empty()) {
3776 assert(outerDimsPerm.size() == outputTileSizes.size() &&
3777 "expected output and outer_dims_perm to have same size");
3781 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
3782 if (ShapedType::isDynamic(inputShape[pos]))
3786 if (!constantTile) {
3787 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
3788 (inputShape[pos] % outputTileSizes[pos] != 0))
3790 }
else if (inputShape[pos] % (*constantTile) != 0) {
3804 auto paddingValue = getPaddingValue();
3807 return emitOpError(
"expected padding_value has ")
3808 << getSourceType().getElementType()
3809 <<
" but got: " << paddingValue.getType();
3812 if (!paddingValue &&
3813 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
3814 getDestType().
getShape(), getOuterDimsPerm(),
3817 "invalid tile factor or output size provided. Only full tiles are "
3818 "supported when padding_value is not set");
3828 for (
auto o : ofrs) {
3830 if (llvm::dyn_cast_if_present<Value>(o))
3831 result.push_back(ShapedType::kDynamic);
3846 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
3848 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
3849 resultShape[tiledDim.value()] = ShapedType::kDynamic;
3852 resultShape[tiledDim.value()] =
ceilDiv(resultShape[tiledDim.value()],
3853 innerTileSizes[tiledDim.index()]);
3857 if (!outerDimsPerm.empty())
3861 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
3876 builder, loc, ceilDivExpr,
3877 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
3879 if (!outerDimsPerm.empty())
3881 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
3886 innerDimsPos, outerDimsPerm);
3892 for (
unsigned i = 0; i < resultDims.size(); ++i) {
3893 if (!ShapedType::isDynamic(resultTypeShape[i]))
3904 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
3909 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
3926 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
3927 if (ShapedType::isDynamic(value))
3928 mixedSizes.push_back(b.
create<DimOp>(loc, source, index).
getResult());
3932 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
3933 int64_t dimPos = std::get<0>(it);
3935 mixedSizes[dimPos] =
ceilDiv(mixedSizes[dimPos], tileSize);
3937 if (!outerDimsPerm.empty())
3938 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
3940 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
3941 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
3942 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
3949 *
this, innerPermutation, outerPermutation);
3950 Value transposedDest =
3951 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
3952 metadata.innerDimsPos, metadata.outerDimsPerm);
3953 return b.
create<PackOp>(loc, getSource(), transposedDest,
3954 metadata.innerDimsPos, metadata.innerTiles,
3955 getPaddingValue(), metadata.outerDimsPerm);
3959 template <
typename OpTy>
3961 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
3962 "applies to only pack or unpack operations");
3963 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
3965 : op.getSourceType();
3967 for (
auto [dimDest,
tile] : llvm::zip(
3968 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
3970 if (!constTileSize || ShapedType::isDynamic(dimDest))
3977 if (getPaddingValue())
3992 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
3994 return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm();
4000 auto packTiles = packOp.getMixedTiles();
4001 auto unPackTiles = unPackOp.getMixedTiles();
4002 if (packTiles.size() != unPackTiles.size())
4004 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
4013 auto srcType = op.getSourceType();
4014 if (llvm::any_of(op.getInnerDimsPos(),
4015 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4017 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4019 return !PackOp::requirePaddingValue(
4020 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4021 op.getOuterDimsPerm(), op.getMixedTiles());
4028 bool changeNeeded =
false;
4029 srcShape.assign(packOp.getSourceType().getShape().begin(),
4030 packOp.getSourceType().getShape().end());
4031 destShape.assign(packOp.getDestType().getShape().begin(),
4032 packOp.getDestType().getShape().end());
4033 llvm::SmallSetVector<int64_t, 4> innerDims;
4034 innerDims.insert(packOp.getInnerDimsPos().begin(),
4035 packOp.getInnerDimsPos().end());
4037 if (!packOp.getOuterDimsPerm().empty())
4039 int srcRank = packOp.getSourceRank();
4040 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
4041 if (innerDims.contains(i))
4044 int64_t destPos = i;
4045 if (!inverseOuterDimsPerm.empty())
4046 destPos = inverseOuterDimsPerm[srcPos];
4047 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4048 ShapedType::isDynamic(destShape[destPos])) {
4051 int64_t size = srcShape[srcPos];
4052 if (ShapedType::isDynamic(size))
4053 size = destShape[destPos];
4054 srcShape[srcPos] = size;
4055 destShape[destPos] = size;
4056 changeNeeded =
true;
4058 return changeNeeded;
4063 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4064 if (unPackOp.getSourceType() != packOp.getDestType())
4066 if (packOp.getPaddingValue() ||
4070 rewriter.
replaceOp(packOp, unPackOp.getSource());
4077 packOp.getPaddingValueMutable().clear();
4086 Value source = packOp.getSource();
4087 if (srcShape != packOp.getSourceType().getShape()) {
4088 auto newSrcType = packOp.getSourceType().clone(srcShape);
4090 rewriter.
create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4092 Value dest = packOp.getDest();
4093 if (destShape != packOp.getDestType().getShape()) {
4094 auto newDestType = packOp.getDestType().clone(destShape);
4096 rewriter.
create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4099 loc, source, dest, packOp.getInnerDimsPos(), packOp.getMixedTiles(),
4100 packOp.getPaddingValue(), packOp.getOuterDimsPerm());
4102 packOp, packOp.getResult().getType(), newOp);
4109 template <
typename PackOrUnpackOp>
4111 RankedTensorType packedTensorType) {
4112 static_assert(std::is_same<PackOrUnpackOp, tensor::PackOp>::value ||
4113 std::is_same<PackOrUnpackOp, tensor::UnPackOp>::value,
4114 "Function meant for pack/unpack");
4119 int64_t numPackedDims = innerDimsPos.size();
4120 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4121 if (orderedDims != innerDimsPos) {
4127 int64_t packedRank = packedTensorType.getRank();
4137 return llvm::all_of(
4138 llvm::seq<int64_t>(0, packedRank - numPackedDims),
4139 [&packedShape](int64_t i) {
return packedShape[i] == 1; });
4142 bool PackOp::isLikePad() {
4143 auto packedTensorType =
4144 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
4149 std::optional<Attribute> paddingValue;
4150 if (
auto pad = adaptor.getPaddingValue())
4153 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4154 getDestType(), paddingValue))
4155 return reshapedSource;
4163 void UnPackOp::getAsmResultNames(
4165 setNameFn(getResult(),
"unpack");
4202 assert(innerDimsPos.size() == innerTiles.size() &&
4203 "number of tile sizes specified must match the specified number of "
4204 "original dimensions to be tiled");
4208 build(builder, state, dest.
getType(), source, dest,
4209 outerDimsPerm.empty() ?
nullptr
4227 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
4229 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
4230 if (srcType.isDynamicDim(i))
4233 mixedSizes.push_back(b.
getIndexAttr(srcType.getDimSize(i)));
4235 if (!outerDimsPerm.empty()) {
4236 applyPermutationToVector<OpFoldResult>(
4240 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
4241 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
4243 auto elemType = srcType.getElementType();
4244 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4248 Value transposedSource,
4252 *
this, innerPermutation, outerPermutation);
4253 return b.
create<UnPackOp>(loc, transposedSource, getDest(),
4254 metadata.innerDimsPos, metadata.innerTiles,
4255 metadata.outerDimsPerm);
4262 bool changeNeeded =
false;
4263 srcShape.assign(op.getSourceType().getShape().begin(),
4264 op.getSourceType().getShape().end());
4265 destShape.assign(op.getDestType().getShape().begin(),
4266 op.getDestType().getShape().end());
4267 llvm::SmallSetVector<int64_t, 4> innerDims;
4268 innerDims.insert(op.getInnerDimsPos().begin(), op.getInnerDimsPos().end());
4270 if (!op.getOuterDimsPerm().empty())
4272 int destRank = op.getDestRank();
4273 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
4274 if (innerDims.contains(i))
4277 int64_t destPos = i;
4278 if (!inverseOuterDimsPerm.empty())
4279 srcPos = inverseOuterDimsPerm[destPos];
4280 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4281 ShapedType::isDynamic(destShape[destPos])) {
4284 int64_t size = srcShape[srcPos];
4285 if (ShapedType::isDynamic(size))
4286 size = destShape[destPos];
4287 srcShape[srcPos] = size;
4288 destShape[destPos] = size;
4289 changeNeeded =
true;
4291 return changeNeeded;
4297 if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
4298 if (packOp.getDestType() != unPackOp.getSourceType())
4300 if (packOp.getPaddingValue() ||
4304 rewriter.
replaceOp(unPackOp, packOp.getSource());
4308 if (
auto dstStyleOp =
4309 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
4310 auto destValue = unPackOp.getDest().cast<
OpResult>();
4311 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
4313 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
4321 Value source = unPackOp.getSource();
4322 if (srcShape != unPackOp.getSourceType().getShape()) {
4323 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
4324 source = rewriter.
create<tensor::CastOp>(loc, newSrcType,
4325 unPackOp.getSource());
4327 Value dest = unPackOp.getDest();
4328 if (destShape != unPackOp.getDestType().getShape()) {
4329 auto newDestType = unPackOp.getDestType().clone(destShape);
4331 rewriter.
create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
4334 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
4335 unPackOp.getOuterDimsPerm());
4337 unPackOp, unPackOp.getResult().getType(), newOp);
4344 bool UnPackOp::isLikeUnPad() {
4345 RankedTensorType packedTensorType = getSourceType();
4351 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
4352 getResult().getType()))
4353 return reshapedSource;
4385 if (isa<InsertSliceOp>(op.getOperation()))
4390 if (isa<LoopLikeOpInterface>(op.getOperation()))
4394 bool hasTensorCastOperand =
4396 if (llvm::isa<BlockArgument>(opOperand.get()))
4398 auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
4399 return castOp && canFoldIntoConsumerOp(castOp);
4401 if (!hasTensorCastOperand)
4411 newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.
get());
4412 if (op.isDpsInit(&opOperand) &&
4413 !llvm::isa<MemRefType>(newOperands.back().getType()))
4414 newResultTypes.push_back(newOperands.back().getType());
4418 Operation *newOp =
clone(rewriter, op, newResultTypes, newOperands);
4421 for (
auto [oldResult, newResult] :
4423 if (newResult.
getType() != oldResult.getType()) {
4424 replacements.push_back(rewriter.
create<tensor::CastOp>(
4425 op->
getLoc(), oldResult.getType(), newResult));
4427 replacements.push_back(newResult);
4440 void TensorDialect::getCanonicalizationPatterns(
4449 #define GET_OP_CLASSES
4450 #include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
static TensorType joinShapes(TensorType one, TensorType two)
Compute a TensorType that has the joined shape knowledge of the two given TensorTypes.
static LogicalResult verifyGatherOrScatterDims(Operation *op, ArrayRef< int64_t > dims, int64_t rank, StringRef gatherOrScatter, StringRef sourceOrDest)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, Operation *op, RankedTensorType expectedType)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
ParseResult parseInferType(OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > optOperand, Type &typeToInfer, Type typeToInferFrom)
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp)
If we have two consecutive InsertSliceOp writing to the same slice, we can mutate the second InsertSl...
static LogicalResult foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op, ShapedType shapedType)
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp)
If we have an ExtractSliceOp consuming an InsertSliceOp with the same slice, we can return the Insert...
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
static int64_t getNumElements(ShapedType type)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
static SliceVerificationResult verifyInsertSliceOp(RankedTensorType srcType, RankedTensorType dstType, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, RankedTensorType *expectedType=nullptr)
Rank-reducing type verification for both InsertSliceOp and ParallelInsertSliceOp.
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static OpFoldResult reshapeConstantSource(DenseElementsAttr source, TensorType result, std::optional< Attribute > cst=std::nullopt)
Try to remove a tensor operation if it would only reshape a constant.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, Type typeToInfer, Type typeToInferFrom)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp)
Folds round-trip extract/insert slice op pairs.
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, RankedTensorType expandedType, RankedTensorType collapsedType)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operations within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineExpr getAffineSymbolExpr(unsigned position)
IntegerAttr getI64IntegerAttr(int64_t value)
AffineExpr getAffineDimExpr(unsigned position)
MLIRContext * getContext() const
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
DenseElementsAttr resizeSplat(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but with a different ...
static DenseElementsAttr getFromRawBuffer(ShapedType type, ArrayRef< char > rawBuffer)
Construct a dense elements attribute from a raw buffer representing the data for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
ArrayRef< char > getRawData() const
Return the raw storage data held by this attribute.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
This is a utility class for mapping one set of IR entities to another.
auto lookupOrDefault(T from) const
Lookup a mapped value within the map.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
unsigned getResultNumber() const
Returns the number of this result.
Pattern to rewrite dynamic offsets/sizes/strides of view/slice-like ops as constant arguments.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< OpOperand > getOpOperands()
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void takeBody(Region &other)
Takes body of another region (that region will have no body after this operation completes).
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
MPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
LogicalResult foldTensorCast(Operation *op)
Performs folding of any operand of op if it comes from a tensor::CastOp that can be folded.
void populateFoldConstantExtractSlicePatterns(RewritePatternSet &patterns, const ControlConstantExtractSliceFusionFn &controlFn=[](ExtractSliceOp op) { return false;})
Patterns to fold the extract slice op with its constant operand.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
bool canFoldIntoConsumerOp(CastOp castOp)
Determines whether tensor::CastOp casts to a more dynamic version of the source tensor.
Value createCanonicalRankReducingInsertSliceOp(OpBuilder &b, Location loc, Value tensor, Value dest)
Create a rank-reducing InsertSliceOp @[0 .
Value createCanonicalRankReducingExtractSliceOp(OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType)
Create a rank-reducing ExtractSliceOp @[0 .
bool isSameTypeWithoutEncoding(Type tp1, Type tp2)
Tests if types are the same when ignoring encoding on ranked tensors.
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given tensor value.
FailureOr< Value > getOrCreateDestination(OpBuilder &b, Location loc, OpResult opResult)
This is a helper function for DestinationStyleOpInterface.
std::function< bool(ExtractSliceOp)> ControlConstantExtractSliceFusionFn
Function to control the folding of constant and extract slice.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
LogicalResult getOrCreateDestinations(OpBuilder &b, Location loc, Operation *op, SmallVector< Value > &result)
This is a helper function for DestinationStyleOpInterface.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is constant integer equal to value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
std::tuple< SmallVector< OpFoldResult >, SmallVector< OpFoldResult >, SmallVector< OpFoldResult > > getOffsetsSizesAndStrides(ArrayRef< Range > ranges)
Given an array of Range values, return a tuple of (offset vector, sizes vector, and strides vector) f...
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
LogicalResult foldDynamicStrideList(SmallVectorImpl< OpFoldResult > &strides)
Returns "success" when any of the elements in strides is a constant value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
int64_t ceilDiv(int64_t lhs, int64_t rhs)
Returns the result of MLIR's ceildiv operation on constants.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
static LogicalResult verifyReshapeLikeTypes(Op op, T expandedType, T collapsedType, bool isExpansion)
Common verifier for reshape-like types.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
bool wouldOpBeTriviallyDead(Operation *op)
Return true if the given operation would be dead if unused, and has no side effects on memory that wo...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
ArrayAttr getReassociationIndicesAttribute(OpBuilder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
bool isReassociationValid(ArrayRef< AffineMap > reassociation, int *invalidIndex=nullptr)
Return true if the reassociation specification is valid, false otherwise.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< NamedAttribute > getPrunedAttributeList(Operation *op, ArrayRef< StringRef > elidedAttrs)
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
LogicalResult foldDynamicOffsetSizeList(SmallVectorImpl< OpFoldResult > &offsetsOrSizes)
Returns "success" when any of the elements in offsetsOrSizes is a constant value.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if the tensor....
LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace ExtractSliceOps.
void operator()(PatternRewriter &rewriter, ExtractSliceOp op, ExtractSliceOp newOp)
Return the canonical type of the result of an extract_slice op.
RankedTensorType operator()(ExtractSliceOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Pattern to compose collapse_shape(expand_shape(src, reassociation_1), reassociation_2).
Pattern to collapse producer/consumer reshape ops that are both collapsing dimensions or are both exp...
This class represents an efficient way to signal success or failure.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Idiomatic saturated operations on values like offsets, sizes, and strides.
static SaturatedInteger wrap(int64_t v)
FailureOr< SaturatedInteger > desaturate(SaturatedInteger other)