#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    // ...
  // ...
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          });
// ...
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                // ...
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           // ...
      });
// ...
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    // ...
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    // ...
  llvm_unreachable("Expected MemRefType or TensorType");
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    // ...
// ...
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      // ...
    }
  }
  // ...
  opBuilder.createBlock(&region, {}, argTypes, argLocs);
  // ...
  regionBuilder(b, *body, attrs);
// ...
                              std::optional<TypeRange> resultTensorTypes,
// ...
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  // ...
  state.addAttributes(attributes);
  // ...
      "operandSegmentSizes",
      // ...
                  static_cast<int32_t>(outputs.size())}));
  // ...
  Region &region = *state.addRegion();
  // ...
                         state.attributes.getAttrs(), regionBuilder);
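// When no explicit result types are supplied, the builder above derives them
// from the outs operands: only ranked tensor outputs contribute results (per
// the IsaPred<RankedTensorType> filter), so, illustratively, an op built with
//   outs(%init : tensor<4x8xf32>)
// gets one result of type tensor<4x8xf32>, while memref outputs add none.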
// ...
                     std::optional<TypeRange> resultTensorTypes,
// ...
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      // ...
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  // ...
                    attributes, regionBuilder);
// ...
                     std::optional<TypeRange> resultTensorTypes,
// ...
  indexingMapsAttrVal =
      // ...
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  // ...
                    attributes, regionBuilder);
                                   bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  // ...
  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             // ...
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             // ...
  if (addOperandSegmentSizes) {
    // ...
    attrs.append("operandSegmentSizes",
                 // ...
                 {static_cast<int32_t>(inputsOperands.size()),
                  static_cast<int32_t>(outputsOperands.size())}));
    // ...
                 {static_cast<int32_t>(inputsOperands.size()),
                  static_cast<int32_t>(outputsOperands.size())}));
  }
  // ...
  std::optional<RegisteredOperationName> info =
      // ...
  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
        return parser.emitError(attrsLoc)
               << "'" << result.name.getStringRef() << "' op ";
      })))
    // ...
// ...
  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    // ...
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }
// ...
                    unsigned numRegionArgs,
// ...
  result.addTypes(outputTensorsTypes);
  // ...
  std::unique_ptr<Region> region = std::make_unique<Region>();
// ...
  if (resultTypes.empty())
class RegionBuilderHelper {
  // ...
      : builder(builder), block(block) {}

  // ...
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    // ...
    builder.setInsertionPointToEnd(&block);
    // The case labels below follow the UnaryFn enum; labels on elided lines
    // are restored from the op each branch creates.
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      // ...
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    // ...
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    // ...
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      // ...
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      // ...
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      // ...
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      // ...
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    // ...
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  // ...
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  // ...
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  Value constant(const std::string &value) {
    // ...
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    // ...
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    // ...
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    // ...
  }
  // ...
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    // ...

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      // ...
    if (copyOp.hasPureBufferSemantics())
      // ...
    rewriter.replaceOp(copyOp, copyOp.getInputs());
    // ...
  }
// ...
  results.add<EraseSelfCopy>(context);
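// Illustrative effect of EraseSelfCopy (a sketch): a copy whose input and
// output are the same value, e.g.
//   linalg.copy ins(%buf : memref<4xf32>) outs(%buf : memref<4xf32>)
// is a no-op and is removed; on tensors the result is replaced by the input.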
template <typename TensorReshapeOp>
// ...
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    // ...
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      // ...
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 // ...
                                                 reshapeOp.getReassociation());
    }
    // ...
  }
// ...
  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    // ...
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      // ...
    // ...
          padOp, "failed to reify tensor.pad op result shape");
    // ...
    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    // ...
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    // ...
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  // ...
  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    // ...
    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      // ...
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        // ...
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // Conservatively give up on dynamic offsets/sizes/strides.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          // ...
        // Otherwise check whether the static intervals are disjoint.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          // ...
        }
      }
      // ...
      firstDest = prevOp.getDest();
    }
    // ...
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      // ...
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      // ...
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    // ...
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        // ...
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
        // ...
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }
    // ...
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    // ...
  }
};
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
  // ...
  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // ...
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    // ...
    Value extractedScalar = fillOp.getInputs()[0];
    // ...
    rewriter.replaceOp(extractOp, extractedScalar);
    // ...
  }
};
// ...
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  // ...
  if (auto paddingValue = packOp.getPaddingValue())
    // ...
  Value packOpDest = packOp.getDest();
  // ...
  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         // ...
}
// ...
  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    // ...
    rewriter.replaceOp(packOp, fillOp.value().result());
    // ...
  }
  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      // ...
                                            copyOp.getOutputs());
      // ...
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      // ...
                                             fillOp.getOutputs());
      // ...
    }
    // ...
  }
// ...
  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      // ...
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      // ...
    }
    // ...
  }
  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      // ...
    }
    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    // ...
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      // ...
      if (fillVal != firstFillVal)
        // ...
      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      // ...
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    // ...
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    // ...
  }
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
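// A representative folding performed by the patterns registered above
// (illustrative IR, a sketch): a tensor.pad whose padding value matches the
// constant of a producing linalg.fill,
//   %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
//   %1 = tensor.pad %0 low[1] high[1] { ... tensor.yield %cst ... }
// is rewritten into a single fill of a freshly created larger tensor.empty.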
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          // ...
      blockArgLocs.push_back(v.getLoc());
    }
  }
  // ...
  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
// ...
void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         // ...
  for (Value v : getRegionInputArgs())
    // ...
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}
void GenericOp::build(
    // ...
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    // ...
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  // ...
                              inputs, outputs, bodyBuild);
}

void GenericOp::build(
    // ...
    StringRef libraryCall,
    // ...
  build(builder, result, resultTensorTypes, inputs, outputs,
        // ...
          return IteratorTypeAttr::get(builder.getContext(), iter);
        // ...
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    // ...
    StringRef libraryCall,
    // ...
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    // ...
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        // ...
        "", bodyBuild, attributes);
}

void GenericOp::build(
    // ...
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        // ...
        "", bodyBuild, attributes);
}
  auto genericAttrNames = linalgTraitAttrNames();
  // ...
  genericAttrNamesSet.insert_range(genericAttrNames);
  // ...
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // ...
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                // ...
              }));
      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          // ...
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    // ...
    p << genericDictAttr;
  }
  // ...
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  // ...
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      // ...
  if (hasExtraAttrs) {
    // ...
  }
  // ...
  if (!getRegion().empty()) {
  DictionaryAttr dictAttr;
  // ...
                     dictAttr.getValue().end());
  // ...
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      // ...
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }
  // ...
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      // ...
             << "unexpected iterator_type (" << s << ")";
    iteratorTypeAttrs.push_back(
        // ...
  }
  // ...
  std::unique_ptr<Region> region = std::make_unique<Region>();
  // ...
  result.addTypes(outputTensorsTypes);
                                   LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      // ...
    effects.emplace_back(
        // ...
  }
  // ...
  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      // ...
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      // ...
    }
  }
}

void GenericOp::getEffects(
    // ...
  if (!linalgOp.hasPureTensorSemantics())
    // ...
template <typename OpTy>
// ...
  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be identical (identity) for the op to be a
    // candidate.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      // ...
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      // ...
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    // ...
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        // ...
      }
      // ...
    }

    if (!linalgOp.hasPureTensorSemantics())
      // ...
    // ...
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        // ...
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // ...
      if (returnType != resultType) {
        // ...
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        // ...
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 // ...
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
      }
      returnedArgs.push_back(returnedArg);
    // ...
    if (returnedArgs.size() != linalgOp->getNumResults())
      // ...
    rewriter.replaceOp(linalgOp, returnedArgs);
    // ...
  }
// ...
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
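// Illustrative (a sketch): a generic with identity maps whose body just
// yields a block argument,
//   %r = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//        ins(%x : tensor<8xf32>) outs(%y : tensor<8xf32>) {
//   ^bb0(%in: f32, %out: f32):
//     linalg.yield %in : f32
//   } -> tensor<8xf32>
// is replaced by %x directly (with a cast/convert if the types differ).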
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      // ...
  }
  // ...
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    // ...
// ...
void MapOp::getAsmBlockArgumentNames(Region &region,
                                     // ...
  for (Value v : getRegionInputArgs())
    // ...
// ...
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
  build(builder, result, TypeRange{}, inputs, init);
  // ...
  if (llvm::isa<RankedTensorType>(initType))
    // ...
  // ...
                       inputs, {}, bodyBuild);
// ...
                                bool initFirst = false) {
  // ...
  for (auto &operand : operands) {
    // ...
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        // ...
  }
  // ...
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  // ...
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    // ...
// ...
  std::optional<OperationName> payloadOpName;
  // ...
  if (failed(operationName))
    // ...
  payloadOpName = operationName.value();
  if (payloadOpName.has_value()) {
    // ...
  }
  // ...
  for (const auto &[operand, bbArg] :
       // ...
    if (bbArg != operand)
      // ...
  for (const auto &[operand, bbArg] :
       // ...
    if (bbArg != operand)
      // ...
  // ...
  std::string attrToElide;
  // ...
  for (const auto &attr : payloadOp->getAttrs()) {
    // ...
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      // ...
    }
  }
  // ...
  Block *mapper = getBody();
  // ...
                        [&](auto arg) { p.printRegionArgument(arg); });
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // The number of inputs must match the arity of the mapper region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            // ...
                         << getInputs().size() << " and " << blockArgs.size();

  // Each mapper parameter must match the element type of the corresponding
  // input.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  // ...
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           // ...
    }
  // ...
  int64_t rank = getInit().getType().getRank();
ArrayAttr MapOp::getIndexingMaps() {
  // ...
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  // ...
}

void MapOp::getEffects(
    // ...
void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        // ...
  for (Value v : getRegionInputArgs())
    // ...
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    // ...
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
void ReduceOp::build(
    // ...
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  // ...
  for (Value init : inits) {
    // ...
    if (llvm::isa<RankedTensorType>(initType))
      // ...
  }
  // ...
                       inputs, inits, bodyBuild);
}
// ...
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  // ...
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
  // ...
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  // ...
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  // ...
}
void ReduceOp::getEffects(
    // ...
// ...
                               StringRef attributeName) {
  // ...
  std::optional<OperationName> payloadOpName;
  // ...
  if (failed(operationName))
    // ...
  payloadOpName = operationName.value();
  // ...
  if (payloadOpName.has_value()) {
    // ...
  }
  // ...
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
  // ...
  Block *mapper = getBody();
  // ...
                        [&](auto arg) { p.printRegionArgument(arg); });
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           // ...
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           // ...
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
  // ...
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();
  // ...
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  // ...
  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  // ...
    return emitOpError()
           << "mismatching number of operands and block arguments";
  // The leading block arguments must match the input element types.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             // ...
  }

  // The trailing block arguments must match the output element types.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             // ...
  }
// ...
  b.create<linalg::YieldOp>(loc, args[0]);
  if (llvm::isa<RankedTensorType>(initType))
    // ...
// ...
void TransposeOp::getAsmResultNames(
    // ...
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
    }
  }
  int64_t rank = getInit().getType().getRank();
  // ...

ArrayAttr TransposeOp::getIndexingMaps() {
  // ...
  int64_t rank = getInit().getType().getRank();
  // ...
          llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
  // ...
}

void TransposeOp::getEffects(
    // ...

LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                // ...
  // Only the tensor type is supported.
  if (!isa<TensorType>(getInput().getType()))
    // ...
  // An empty permutation (rank 0) folds to the input.
  if (getPermutation().size() == 0) {
    result.push_back(getInput());
    // ...
  }
  // An identity permutation also folds to the input.
  // ...
    result.push_back(getInput());
    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
    if (!defTransposeOp)
      // ...
    foldedPerms.reserve(perms.size());
    for (int64_t perm : perms)
      foldedPerms.push_back(defPerms[perm]);
    // ...
        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
        // ...
// ...
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    // ...
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(invertPerm[dimensions[i]]);
    // ...
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    // ...
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(i)) {
        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
                           // ...
      } else {
        // ...
                                           broadcastInputTy.getDimSize(i)));
      }
    }
    // ...
    Value transposeInit = rewriter.create<tensor::EmptyOp>(
        transposeOp.getLoc(), transposeResultShapes,
        broadcastInputTy.getElementType());

    // Rewrite transpose(broadcast(input)) as broadcast(transpose(input)).
    Value transposeResult =
        rewriter
            .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
                                 // ...
    // ...
        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
  if (llvm::isa<RankedTensorType>(initType))
    // ...
// ...
void BroadcastOp::getAsmResultNames(
    // ...
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         // ...
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  // ...
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  // ...
  // Mapping from input dims to init dims (init dims not listed in
  // `dimensions` are carried over from the input).
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimension is mapped from the input; the input and init sizes must
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }
  int64_t rank = getInit().getType().getRank();
  // ...

ArrayAttr BroadcastOp::getIndexingMaps() {
  // ...
  int64_t rank = getInit().getType().getRank();
  // ...
}

void BroadcastOp::getEffects(
    // ...
// ...
  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
// ...
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  // ...
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    // ...
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    // ...
    if (isa<MemRefType, RankedTensorType>(elementType))
      // ...
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  // ...
}

  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
    // ...
  return emitOpError("expected parent op with LinalgOp interface");
// ...
  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
  // ...
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
  for (unsigned i = 0; i < num; ++i)
    // ...
// ...
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
// ...
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    // ...
    for (auto size : memref.getShape())
      // ...
    // ...
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      // ...
    }
    // ...
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    // ...
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    // ...
  }
// ...
  assert(isa<LinalgOp>(op));
  // ...
  std::string fun = "";
  // ...
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  // ...
  std::replace(name.begin(), name.end(), '.', '_');
  llvm::raw_string_ostream ss(name);
  // ...
  return std::string();
  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // ...
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      // ...
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        // ...
      }
    }
    // ...
  }
// ...
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  // ...
  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    // ...
    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    // ...
    if (castOp->getBlock() != linalgOp->getBlock())
      // ...
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    // ...
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // ...
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    // ...
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    // ...
                                         linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());
    // ...
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    // ...
    results[resultNumber] = castBack;
    // ...
  }
    if (linalgOp.isScalar(&opOperand))
      // ...
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
    // ...
    if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
      Value castSource = castOp.getSource();
      auto castSourceType =
          llvm::dyn_cast<RankedTensorType>(castSource.getType());
      if (castSourceType && castSourceType.hasStaticShape())
        sourceShape = castSourceType.getShape();
    }
    // ...
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        // ...
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
static void createNewOperandWithStaticSizes(
    // ...
    bool &changeNeeded) {
  // ...
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    // ...
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    // ...
  }
  // ...
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  // ...
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    // ...
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      // ...
    }
    // The size of this dimension is known from another operand: use it and
    // remember that a replacement operand (cast to the refined type) is
    // needed.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  // ...
                                   sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    // ...
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    // ...
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      // ...
    // All indexing maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      // ...
    // ...
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
    // ...
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());
    // ...
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }
    // ...
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    // ...
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      // ...
      replacements.push_back(
          (newType != oldType)
              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
              // ...
    }
    rewriter.replaceOp(linalgOp, replacements);
    // ...
  }
  ShapedType inputType = getInputOperandType();
  ShapedType outputType = getOutputOperandType();
  // ...
    return emitOpError("incompatible output shape");

  int64_t inputRank = getInputOperandRank();
  int64_t dimension = getDimension();
  if ((dimension < 0) || (dimension >= inputRank))
    return emitOpError("incorrect dimension specified");
  // ...
  int64_t operandRank = getInputOperandRank();
  // ...
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getInput();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  // ...
  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
FailureOr<TilingResult>
// ...
  int64_t rank = getInputOperandRank();
  // ...
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
  // ...
    return emitOpError("failed to compute input slice");

  tiledOperands.emplace_back(inputSlice->getResult(0));
  // ...
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
  // ...
    return emitOpError("failed to compute output slice");

  tiledOperands.emplace_back(outputSlice->getResult(0));
  // ...
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  // ...
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
// ...
  if (resultNumber == 0) {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
  }
  Location loc = getOperation()->getLoc();
  // ...
  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(dim)) {
      // ...
      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
      // ...
    }
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
// ...
void SoftmaxOp::getEffects(
    // ...
    if (!llvm::isa<MemRefType>(operand.getType()))
      // ...
        &getOperation()->getOpOperand(index), 0,
        // ...
  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      // ...
                             int64_t dim, bool allParallel = false) {
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  if (!allParallel)
    iteratorTypes[dim] = utils::IteratorType::reduction;
  // ...
  for (int i = 0; i < inputRank; i++) {
    // ...
  }
  // ...
  return std::make_tuple(iteratorTypes, indexingMaps);
template <typename T>
// ...
  auto inputType = cast<ShapedType>(input.getType());
  // ...
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
      // ...
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  // ...
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<T>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  auto inputType = cast<ShapedType>(input.getType());
  // ...
  int64_t inputRank = inputShape.size();
  // ...
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  // Add the output map, identical to the input map.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      // ...
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
        Value result = b.create<math::ExpOp>(loc, diff);
        b.create<linalg::YieldOp>(loc, result);
      });
                        Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  // ...
  int64_t inputRank = inputShape.size();
  // ...
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  // Add the output map, identical to the numerator map.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      // ...
      indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  // ...
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  // ...
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  // Step 1: compute the max along the reduction dimension.
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  // ...
      elementType, b, loc,
      // ...
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
          // ...
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);

  // Steps 2 and 3: subtract the max and exponentiate, then sum along the
  // same dimension.
  // ...
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
  // ...
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: divide the numerator by the denominator.
  // ...
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
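// In scalar terms the decomposition above computes, along `reductionDim`:
//   softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
// (a sketch of the intent; the max subtraction is for numerical stability).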
  auto filterType = cast<ShapedType>(getFilter().getType());
  // ...
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  // ...
  if (filterH != r && filterH != 1)
    return emitOpError("expected filter height to be either r or 1");
  if (filterW != r && filterW != 1)
    return emitOpError("expected filter width to be either r or 1");
  if (filterH == 1 && filterW == 1)
    return emitOpError("expected either filter height or width to equal r");

  // ...
  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
  expectedOutputShape.push_back(filterShape[getFilterFDim()]);

  auto outputType = cast<ShapedType>(getOutput().getType());
  // ...
    return emitOpError("unexpected output shape");
WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
  // ...
  Value filter = getFilter();
  int64_t filterRank = getFilterOperandRank();
  // ...
  for (unsigned dim = 0; dim < filterRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
    loopBounds[dim].stride = oneAttr;
  }
  // ...
}

SmallVector<utils::IteratorType>
WinogradFilterTransformOp::getLoopIteratorTypes() {
  int64_t filterRank = getFilterOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
  ShapedType filterType = getFilterOperandType();
  // ...
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  // ...
  int64_t alpha = m + r - 1;
  int64_t alphaH = filterH != 1 ? alpha : 1;
  int64_t alphaW = filterW != 1 ? alpha : 1;
  // ...
  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
  // ...
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
// ...
  ShapedType filterType = getFilterOperandType();
  // ...
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  // ...
  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  // ...
  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);
  // ...
  int64_t outputRank = getOutputOperandRank();
  // ...
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);
  // ...
  resultTypes.push_back(tiledOperands[1].getType());
  // ...
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  auto inputType = cast<ShapedType>(getInput().getType());
  // ...
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  // ...
  int64_t tileSize = m + r - 1;

  auto outputType = cast<ShapedType>(getOutput().getType());
  // ...
  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
  // ...
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : inputH;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : inputW;
  }
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
  // ...
    return emitOpError("unexpected output shape");
WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
  // ...
  Value output = getOutput();
  int64_t outputRank = getOutputOperandRank();
  // ...
  for (unsigned dim = 0; dim < outputRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // ...
    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
    loopBounds[dim].stride = oneAttr;
  }
  // ...
}

SmallVector<utils::IteratorType>
WinogradInputTransformOp::getLoopIteratorTypes() {
  int64_t outputRank = getOutputOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
  ShapedType outputType = getOutputOperandType();
  // ...
  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
  // ...
  int64_t alpha = m + r - 1;
  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
  // ...
  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});
FailureOr<TilingResult>
// ...
  ShapedType outputType = getOutputOperandType();
  // ...
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];
  // ...
  auto identityAffineMap =
      // ...
  auto offsetAffineMap =
      // ...
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
  // ...
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  // ...
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  // ...
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
  // ...
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  // ...
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  // ...
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);
  // ...
  int64_t outputRank = getOutputOperandRank();
  // ...
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);
  // ...
  resultTypes.push_back(tiledOperands[1].getType());
  // ...
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  auto valueType = cast<ShapedType>(getValue().getType());
  // ...
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  // ...
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  // ...
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expected input height to equal the input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expected input width to equal the input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  // ...
    return emitOpError("unexpected output shape");
WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
  // ...
  Value value = getValue();
  int64_t valueRank = getValueOperandRank();
  // ...
  for (unsigned dim = 0; dim < valueRank; ++dim) {
    loopBounds[dim].offset = zeroAttr;
    // ...
    loopBounds[dim].size = getDimValue(builder, loc, value, dim);
    loopBounds[dim].stride = oneAttr;
  }
  // ...
}

SmallVector<utils::IteratorType>
WinogradOutputTransformOp::getLoopIteratorTypes() {
  int64_t valueRank = getValueOperandRank();
  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
                                                 utils::IteratorType::parallel);
  return iteratorTypes;
}
  auto identityAffineMap =
      // ...
  ShapedType valueType = getValueOperandType();
  // ...
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  // ...
      builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
      offsets[getValueTileHDim()]);
  // ...
      builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
      offsets[getValueTileWDim()]);
  // ...
      builder, loc, affineMap, sizes[getValueTileHDim()]);
  // ...
      builder, loc, affineMap, sizes[getValueTileWDim()]);
  // ...
  resultOffsets.append(
      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
  // ...
      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
  ShapedType valueType = getValueOperandType();
  // ...
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  // ...
  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  // ...
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);
  // ...
  int64_t outputRank = getOutputOperandRank();
  // ...
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);
  // ...
  resultTypes.push_back(tiledOperands[1].getType());
  // ...
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
  llvm::set_union(explicitSet, defaultSet);
  return explicitSet == defaultSet;
// ...
      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
  // ...
  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
  // ...
    return matmulOp->emitOpError()
           << "Unexpected dim expression in map result.";
  // ...
  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
    return matmulOp->emitOpError()
           << "Invalid broadcast requested, should be (d2).";
  }
                                       AffineMap defaultIndexingMap,
                                       bool isLHS) {
  // ...
    return batchMatmulOp->emitOpError()
           << "Unexpected result dim expression (outside the set of default "
              // ...
    return batchMatmulOp->emitOpError()
           << "no. of result dim expressions exceeds 3.";

  auto hasValidBatchDim = [](AffineMap map) {
    // ...
  };
  // ...
    if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
      return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
  } else if (!hasValidBatchDim(opIndexingMap)) {
    return batchMatmulOp->emitOpError()
           << "Invalid batch dimension expression.";
  }
  // ...
    return batchMatmulOp->emitOpError()
           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
           // ...
  auto areValidOutputResultDim = [](AffineMap outputMap) {
    return outputMap.getResult(0).isFunctionOfDim(0) &&
           outputMap.getResult(1).isFunctionOfDim(1) &&
           outputMap.getResult(2).isFunctionOfDim(2);
  };

  if (!areValidOutputResultDim(opIndexingMap))
    return batchMatmulOp->emitOpError()
           << "Invalid output map result dimension.";
static LogicalResult
// ...
      batchMatmulOp.getIndexingMapsArray();
  // ...
      batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());

  if (opIndexingMaps.size() != 3)
    return batchMatmulOp->emitOpError()
           << "Indexing_map attribute must have 3 affine maps.";
  // ...
  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];

  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
    // ...
  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
                             // ...
  // ...
  return indexingMaps;
          utils::IteratorType::parallel,
          utils::IteratorType::reduction};
// ...
unsigned MatmulOp::getNumRegionArgs() { return 3; }

std::string MatmulOp::getLibraryCallName() {
  // ...
}

bool MatmulOp::hasDynamicIndexingMaps() { return true; }
// ...
bool MatmulOp::hasUserDefinedMaps() {
  // ...
  return defaultMaps != explicitMaps;
}
// ...
         "MatmulOp regionBuilder expects 3 (>=0) args");
  RegionBuilderHelper helper(b, block);
  // ...
  TypeFn castVal = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      // ...
  }
  // ...
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
  Value value4 =
      helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
  yields.push_back(value4);
  helper.yieldOutputs(yields);
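// The region built above corresponds to (illustrative, f32 operands):
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32
// with the optional "cast" (signed by default) applied to %a and %b when the
// input element types differ from the accumulator element type.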
bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
  // ...
}

// parseIndexingMapsAttr: parses the optional `indexing_maps` attribute and
// checks that every element is an affine map.
  ArrayAttr arrayAttr;
  // ...
  if (llvm::any_of(arrayAttr,
                   [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
    return parser.emitError(parser.getCurrentLocation())
           << "element of indexing_maps array is not an affine_map";
// MatmulOp::parse: if no `indexing_maps` was provided, install the defaults.
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr))
    return failure();

  if (*indexingMapsAttr == nullptr) {
    auto indexingMapAttrs = llvm::map_to_vector(
        MatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
    // ...
  }

  result.addAttribute("indexing_maps", *indexingMapsAttr);
  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
                                MatmulOp::getRegionBuilder());

// MatmulOp::print: elide `indexing_maps` when it matches the defaults.
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
  // ...
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  // ...

// MatmulOp::verify: only user-defined maps require the extended checks.
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
    if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();
void MatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
// ContractOp::getIteratorTypesArray: dims that appear in the output map are
// parallel; all remaining dims are reductions.
SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
  // ...
  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
  for (auto result : outAffineMap.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    assert(dimExpr && "affine_map is a projected permutation");
    dimsInOutput[dimExpr.getPosition()] = true;
  }

  SmallVector<utils::IteratorType> iteratorTypes;
  for (auto dimOccursInOutput : dimsInOutput)
    iteratorTypes.push_back(dimOccursInOutput
                                ? utils::IteratorType::parallel
                                : utils::IteratorType::reduction);

  return iteratorTypes;
}
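// Example: with maps (d0, d2) x (d2, d1) -> (d0, d1), the output map uses d0
// and d1, so dimsInOutput = {true, true, false}; d0 and d1 become parallel
// and d2 (the contracted dim) becomes a reduction.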
unsigned ContractOp::getNumRegionArgs() { return 3; }

// ContractOp::regionBuilder: like matmul, emits arg2 + cast(arg0) * cast(arg1).
  assert(block.getNumArguments() == 3 &&
         "ContractOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);

  TypeFn castSignedness = TypeFn::cast_signed;
  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castSignedness = attr.getValue();
  }
  // ...
  Type outType = block.getArgument(2).getType();
  Value lhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
  Value rhsAtOutType =
      helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
  Value productAtOutType =
      helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      productAtOutType);
  helper.yieldOutputs({result});
// ContractOp::parse: `indexing_maps` is mandatory.
  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
    return parser.emitError(parser.getCurrentLocation(),
                            "expected 'indexing_maps' attribute");
  result.addAttribute("indexing_maps", *indexingMapsAttr);
  // ...

// ContractOp::print: the indexing maps are always printed.
  p << " indexing_maps = [";
  llvm::interleaveComma(getIndexingMaps(), p,
                        [&](Attribute attr) { p.printAttribute(attr); });
  p << "]";
  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
                         {"indexing_maps", "operandSegmentSizes"});
// ContractOp::verify: every indexing map must be a projected permutation over
// a common iteration space, and at least one dim must be contracting (used by
// both inputs, absent from the output).
  int iterationSpaceDims = -1;
  // ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ...
      return emitError("provided affine_map is not a projected permutation");
    // ...
    if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
      // ...
        return emitError("ranks of shaped operand and results of "
                         "corresponding affine_map differ");
    } else {
      // ...
        return emitError("affine_map specifies shaped access while operand "
                         "has non-shaped type");
    }

    if (iterationSpaceDims == -1) {
      // ...
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError("iteration spaces of provided affine_maps differ");
    }

    // Count how often each iteration-space dim is accessed by ins and outs.
    // ...
      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    // ...
  };

  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
                 SmallVector<bool>{true, true, false})) {
    if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure();
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";
    // ...
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError("'indexing_maps' do not specify a contracting dimension");
void ContractOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
SmallVector<AffineMap>
BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
  AffineExpr d0, d1, d2, d3;
  SmallVector<AffineMap> indexingMaps;
  bindDims(context, d0, d1, d2, d3);
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
  return indexingMaps;
}
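// Read back, the three default maps encode batch matmul over
// (batch = d0, m = d1, n = d2, k = d3):
//
//   A: (d0, d1, d2, d3) -> (d0, d1, d3)   // batch x M x K
//   B: (d0, d1, d2, d3) -> (d0, d3, d2)   // batch x K x N
//   C: (d0, d1, d2, d3) -> (d0, d1, d2)   // batch x M x N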
SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
  return SmallVector<utils::IteratorType>{
      utils::IteratorType::parallel, utils::IteratorType::parallel,
      utils::IteratorType::parallel, utils::IteratorType::reduction};
}

unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }

std::string BatchMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}

bool BatchMatmulOp::hasUserDefinedMaps() {
  SmallVector<AffineMap, 3> defaultMaps =
      getDefaultIndexingMaps(this->getContext());
  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
  return defaultMaps != explicitMaps;
}

bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  enum Indices { batchPos, mPos, nPos, kPos };
  // ...
}
3951 "BatchMatmulOp regionBuilder expects 3 (>=0) args");
3952 RegionBuilderHelper helper(b, block);
3955 TypeFn castVal = TypeFn::cast_signed;
3956 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3957 return attr.
getName() ==
"cast";
3959 if (castIter != attrs.end()) {
3960 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3965 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
3966 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
3967 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3969 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
3970 yields.push_back(addVal);
3971 helper.yieldOutputs(yields);
// BatchMatmulOp::parse.
  // ...
    if (!isa<AffineMapAttr>(mapAttr)) {
      return parser.emitError(parser.getCurrentLocation(),
                              "expected affine map attribute");
    }
    indexingMapsAttr.push_back(mapAttr);
  // ...
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  // ...
  return parseNamedStructuredOp(parser, result,
                                BatchMatmulOp::getNumRegionArgs(),
                                BatchMatmulOp::getRegionBuilder());
// BatchMatmulOp::print: elide `indexing_maps` when it matches the defaults.
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      BatchMatmulOp::getDefaultIndexingMaps(getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
  // ...
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  // ...

// BatchMatmulOp::verify: check all three maps when user-defined.
  if (!hasUserDefinedMaps())
    return success();

  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
    if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
      return failure();
  }
  return success();

LogicalResult BatchMatmulOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void BatchMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
struct ArityGroupAndKind {
  ElementwiseArityGroup arityGroup;
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};

unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  return static_cast<unsigned>(arityGroup);
}

// getArityGroupAndKind: classifies an ElementwiseKind into its arity group
// (unary/binary/ternary) and the function within that group.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
  constexpr int lastBinary =
      static_cast<int>(ElementwiseCaseLimits::LastBinary);
  constexpr int lastTernary =
      static_cast<int>(ElementwiseCaseLimits::LastTernary);

  int val = static_cast<int>(kind);
  ArityGroupAndKind result;

  if (val < lastUnary) {
    result.arityGroup = ElementwiseArityGroup::Unary;
    result.kind.unaryFn = static_cast<UnaryFn>(val);
    return result;
  }
  if (val < lastBinary) {
    result.arityGroup = ElementwiseArityGroup::Binary;
    result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
    return result;
  }
  if (val >= lastTernary) {
    llvm_unreachable("unhandled ElementwiseFn");
  }
  result.arityGroup = ElementwiseArityGroup::Ternary;
  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
  return result;
}
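// The ElementwiseKind enum is laid out as consecutive ranges
// [unary fns | binary fns | ternary fns], with ElementwiseCaseLimits marking
// the end of each range, so the classification above is pure range
// arithmetic. For instance, if lastUnary were 8 and lastBinary 20
// (hypothetical values, purely for illustration), a kind with raw value 12
// would be Binary with binaryFn = static_cast<BinaryFn>(12 - 8).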
// ElementwiseOp: all loops are parallel.
  auto rank = getResultRank();
  // ...

SmallVector<AffineMap>
ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
                                      MLIRContext *context) {
  // ...
}

// ElementwiseOp::parse: the `kind` attribute is mandatory and determines the
// arity, and therefore the number of region arguments.
  mlir::linalg::ElementwiseKind elemwiseKindVal;
  // ...
    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
    if (!elemwiseKindAttr)
      return parser.emitError(parser.getCurrentLocation(),
                              "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  // ...
    return parser.emitError(parser.getCurrentLocation(),
                            "expected operation 'kind' attribute");
  // ...
    if (!isa<AffineMapAttr>(mapAttr))
      return parser.emitError(parser.getCurrentLocation(),
                              "expected affine map attribute");
    indexingMapsAttr.push_back(mapAttr);
  // ...
  unsigned numRegionArgs =
      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1; // + output
  if (failed(parseNamedStructuredOp(parser, result, numRegionArgs,
                                    ElementwiseOp::getRegionBuilder()))) {
    return parser.emitError(parser.getCurrentLocation(),
                            "unable to parse elemwise op");
  }
  // If no indexing maps were provided, use the defaults derived from the
  // result rank.
  if (indexingMapsAttr.empty()) {
    // ...
    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
    if (!shapedType)
      return parser.emitError(parser.getCurrentLocation(),
                              "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
                                              parser.getContext()),
        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  }
  // ...

// ElementwiseOp::print: elide `indexing_maps` when it matches the defaults.
  unsigned numDims = getResultRank();
  // ...
  auto indexingMaps = llvm::map_to_vector(
      ElementwiseOp::getDefaultIndexingMaps(arity + 1, numDims, getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(getIndexingMaps(), p,
                          [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }
// ElementwiseOp::regionBuilder: dispatch on the arity group recovered from
// the `kind` attribute.
  ElementwiseKind elemwiseKind;
  for (auto attr : attrs) {
    // ...
      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
      assert(kindAttr && "op kind attribute incorrectly set");
      elemwiseKind = kindAttr.getValue();
    // ...
  }

  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
  auto arityGroup = groupAndKind.arityGroup;
  auto kind = groupAndKind.kind;
  assert(block.getNumArguments() ==
             getArityGroupAsUInt(arityGroup) + 1 /*output*/
         && "Elementwise regionBuilder number of block args mismatch");

  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;
  Value result;

  if (arityGroup == ElementwiseArityGroup::Unary) {
    // ...
  } else if (arityGroup == ElementwiseArityGroup::Binary) {
    // ...
  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
    // ...
  } else {
    assert(false && "found unhandled category in elemwise");
  }

  yields.push_back(result);
  helper.yieldOutputs(yields);

LogicalResult ElementwiseOp::fold(FoldAdaptor,
                                  SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

void ElementwiseOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
// getNewMixedTileSizes: fold tile sizes to static attributes wherever the new
// packed type has a static trailing dim; preserve dynamic sizes otherwise.
static SmallVector<OpFoldResult>
getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
                     SmallVector<OpFoldResult> mixedTiles) {
  SmallVector<OpFoldResult> newMixedTileSizes;
  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
                               .getShape()
                               .take_back(mixedTiles.size()),
                           mixedTiles)) {
    int64_t shape = std::get<0>(it);
    if (shape == ShapedType::kDynamic) {
      newMixedTileSizes.push_back(std::get<1>(it));
      continue;
    }

    // For a static dim, the tile size must agree with it.
    OpFoldResult tile = std::get<1>(it);
    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
      // The tile size is already an attribute: keep it.
      newMixedTileSizes.push_back(tile);
    } else {
      assert(getConstantIntValue(tile) == shape &&
             "tile size and dim size don't match!");
      newMixedTileSizes.push_back(rewriter.getIndexAttr(shape));
    }
  }

  return newMixedTileSizes;
}
template <typename OpTy>
static LogicalResult
reifyResultShapesImpl(OpTy op, OpBuilder &builder,
                      ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  int64_t destRank = op.getDestRank();
  // ...
  reifiedReturnShapes[0] =
      tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
  return success();
}
template <typename OpTy>
static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
  // ...
  assert(tiles.size() == dimsToTile.size() &&
         "tiles must match indices of dimension to block");
  // Bind each tile size to the dimension it blocks.
  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
    dimAndTileMapping[dimsToTile[i]] = tiles[i];
  return dimAndTileMapping;
}
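// Example: for a pack with inner_dims_pos = [0, 2] and inner_tiles = [8, 16],
// the returned mapping is {0 -> 8, 2 -> 16}; dimensions absent from the map
// are not tiled.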
template <typename OpTy>
static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Builder builder(op);
  SmallVector<OpFoldResult> mixedInnerTiles;
  unsigned dynamicValIndex = 0;
  for (int64_t staticTile : op.getStaticInnerTiles()) {
    if (!ShapedType::isDynamic(staticTile))
      mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
    else
      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
  }
  return mixedInnerTiles;
}

template <typename OpTy>
static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  // ...
}
// isInvalidPackingPosSpecification: returns true if `dimsPos` is invalid,
// i.e. too long, containing duplicates, or out of range.
static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
                                             size_t rank) {
  size_t dimsPosSize = dimsPos.size();
  if (dimsPosSize > rank)
    return true;
  // ...
  if (dimsPosSize != uniqued.size())
    return true;
  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
    return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
  });
}
// areAllInBound: returns true if each dim of `sourceShape` is no larger than
// the corresponding dim of `limitShape`; dynamic dims are always in bound.
static bool areAllInBound(ArrayRef<int64_t> sourceShape,
                          ArrayRef<int64_t> limitShape) {
  assert(
      sourceShape.size() == limitShape.size() &&
      "expected source shape rank, and limit of the shape to have same rank");
  return llvm::all_of(
      llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
        int64_t sourceExtent = std::get<0>(it);
        int64_t limit = std::get<1>(it);
        return ShapedType::isDynamic(sourceExtent) ||
               ShapedType::isDynamic(limit) || sourceExtent <= limit;
      });
}
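// Example: sourceShape = [4, ?, 8] is in bound w.r.t. limitShape = [4, 10, 8]
// (the dynamic extent always passes), while [5, 2, 8] is not, since 5 > 4 in
// the first dimension.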
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  Operation *op = packOrUnPack.getOperation();

  // Return true if we have a zero-value tile.
  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
    return llvm::any_of(
        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
  };

  // Verify tiles: every provided tile must be non-zero.
  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
  if (hasZeros(mixedTiles))
    return op->emitError("invalid zero tile factor");

  // Verify inner_dims_pos and outer_dims_perm.
  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
                                      ? packOrUnPack.getSourceType()
                                      : packOrUnPack.getDestType();
  size_t unpackedRank = unpackedType.getRank();
  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
    return op->emitError("invalid inner_dims_pos vector");
  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
    return op->emitError("invalid outer_dims_perm vector");
  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
    return op->emitError("outer_dims_perm must be a permutation or empty");

  // Tiling factors must be at most the input (pack) / output (unpack) rank,
  // and must match the number of dimensions to tile.
  if (mixedTiles.size() > unpackedRank) {
    return op->emitError("tiling factors must be less than or equal to the "
                         "input rank for pack or output rank for unpack");
  }
  if (mixedTiles.size() != innerDimsPos.size()) {
    return op->emitError(
        "tiling factors must equal the number of dimensions to tile");
  }

  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? packOrUnPack.getDestType()
                              : packOrUnPack.getSourceType();
  size_t packedRank = packedType.getRank();
  // Require the packed rank to match unpacked rank + number of tiling factors.
  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
  if (expectedPackedRank != packedRank) {
    return op->emitError(
               "packed rank != (unpacked rank + num tiling factors), got ")
           << packedRank << " != " << expectedPackedRank;
  }

  // Verify that the packed shape can hold the expected packed data.
  RankedTensorType expectedPackedType = PackOp::inferPackedType(
      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
    return op->emitError("the shape of output is not large enough to hold the "
                         "packed data. Expected at least ")
           << expectedPackedType << ", got " << packedType;
  }
  if (!llvm::all_of(
          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
                    mixedTiles),
          [](std::tuple<int64_t, OpFoldResult> it) {
            int64_t shape = std::get<0>(it);
            if (Attribute attr =
                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
              int64_t staticTileSize = intAttr.getValue().getSExtValue();
              return shape == staticTileSize;
            }
            return ShapedType::isDynamic(shape);
          })) {
    return op->emitError("mismatch in inner tile sizes specified and shaped of "
                         "tiled dimension in the packed type");
  }
  return success();
}
// Metadata produced when transposing a pack/unpack op.
struct PackOrUnPackTransposeResult {
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTiles;
  SmallVector<int64_t> outerDimsPerm;
};

template <typename OpTy>
static PackOrUnPackTransposeResult
commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
                                   ArrayRef<int64_t> innerPermutation,
                                   ArrayRef<int64_t> outerPermutation) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
         "some permutation must be non-empty");
  PackOrUnPackTransposeResult metadata;
  metadata.innerDimsPos =
      SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
  metadata.innerTiles =
      SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
                             ? packOrUnPackOp.getSourceRank()
                             : packOrUnPackOp.getDestRank();
  metadata.outerDimsPerm =
      packOrUnPackOp.getOuterDimsPerm().empty()
          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
          : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
  if (!innerPermutation.empty()) {
    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
           isPermutationVector(innerPermutation) &&
           "invalid inner permutation");
    applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
    applyPermutationToVector(metadata.innerTiles, innerPermutation);
  }
  if (!outerPermutation.empty()) {
    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
           isPermutationVector(outerPermutation) &&
           "invalid outer permutation");
    applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
  }
  return metadata;
}
void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}

void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  // ...
  build(builder, state, dest.getType(), source, dest,
        paddingValue ? *paddingValue : nullptr,
        // ...
}
// Outer (untiled) dims of the packed result: the leading input-rank dims.
  ShapedType inputType = getSourceType();
  int64_t inputRank = inputType.getRank();
  return getDestType().getShape().take_front(inputRank);

// Tiled outer dims: the packed sizes at each inner_dims_pos index.
  auto packedShape = getDestType().getShape();
  // ...
    res.push_back(packedShape[index]);
  // ...

// PackOp::requirePaddingValue: padding is needed as soon as some tile does
// not evenly divide the corresponding (static) dimension.
  // ...
      outputShape.take_front(inputShape.size()));
  // ...
         "expected output and outer_dims_perm to have same size");
  // ...
    if (ShapedType::isDynamic(inputShape[pos]))
      continue;
    // ...
    if (!constantTile) {
      if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      return true;
    }
  // ...
// PackOp::verify.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError("expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  if (!paddingValue &&
      requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
                          getDestType().getShape(), getOuterDimsPerm(),
                          getMixedTiles())) {
    return emitOpError(
        "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  // ...
// asShapeWithAnyValueAsDynamic: convert OpFoldResults to int64_t shape
// entries, mapping any Value to kDynamic.
  SmallVector<int64_t> result;
  for (auto o : ofrs) {
    if (llvm::dyn_cast_if_present<Value>(o))
      result.push_back(ShapedType::kDynamic);
    // ...
  }
  // ...

// getPackOpResultTypeShape: helper for PackOp::{getResultShape,inferPackedType}.
  // ...
  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
      continue;
    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
      resultShape[tiledDim.value()] = ShapedType::kDynamic;
      continue;
    }
    resultShape[tiledDim.value()] = llvm::divideCeilSigned(
        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
  }
  // ...
  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
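// Worked example of the ceil-division above: an outer dimension of size 10
// tiled by 4 becomes divideCeilSigned(10, 4) = 3 outer tiles, and the tile
// size 4 is appended as a trailing dimension of the packed shape.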
// Dynamic variant: compute the outer sizes with a folded affine ceildiv.
    resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
        builder, loc, ceilDivExpr,
        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
  // ...
  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
  // ...

  // Skip entries whose extent in the result type is already static.
  for (unsigned i = 0; i < resultDims.size(); ++i) {
    if (!ShapedType::isDynamic(resultTypeShape[i]))
      continue;
    // ...
  }
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(resultShape, sourceType.getElementType());
}

Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  // ...
  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           llvm::cast<RankedTensorType>(source.getType()).getShape())) {
    if (ShapedType::isDynamic(value))
      mixedSizes.push_back(
          b.create<tensor::DimOp>(loc, source, index).getResult());
    else
      mixedSizes.push_back(b.getIndexAttr(value));
  }
  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
    int64_t dimPos = std::get<0>(it);
    OpFoldResult tileSize = std::get<1>(it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);

  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
                              metadata.innerDimsPos, metadata.outerDimsPerm);
  return b.create<PackOp>(loc, getSource(), transposedDest,
                          metadata.innerDimsPos, metadata.innerTiles,
                          getPaddingValue(), metadata.outerDimsPerm);
}
// areTilesAndTiledDimsAllConstant: returns true if the tiles and the tiled
// dims are all constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
                "applies to only pack or unpack operations");
  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
                              ? op.getDestType()
                              : op.getSourceType();
  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
  for (auto [dimDest, tile] : llvm::zip(
           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
    std::optional<int64_t> constTileSize = getConstantIntValue(tile);
    if (!constTileSize || ShapedType::isDynamic(dimDest))
      return false;
  }
  return true;
}

// Speculatability: with an explicit padding value, the pack cannot read out
// of bounds and is speculatable.
  if (getPaddingValue())
    return Speculation::Speculatable;
  // ...
// hasSameInnerOuterAttribute: a pack/unpack pair must agree on inner_dims_pos
// and outer_dims_perm to cancel out.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
  // ...
  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
    return false;
  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
    return true;
  // ...
}

// haveSameTiles: a pack/unpack pair must agree on the mixed tile sizes.
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
  auto packTiles = packOp.getMixedTiles();
  auto unPackTiles = unPackOp.getMixedTiles();
  if (packTiles.size() != unPackTiles.size())
    return false;
  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
    if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
      return false;
  }
  return true;
}

// paddingIsNotNeeded: returns true if the pack op does not need a padding
// value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  if (llvm::any_of(op.getInnerDimsPos(),
                   [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
    return false;
  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
    return false;
  return !PackOp::requirePaddingValue(
      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
      op.getOuterDimsPerm(), op.getMixedTiles());
}
// inferStaticShape(PackOp): propagate static extents between source and dest
// shapes for all untiled outer dimensions; returns true if anything changed.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(packOp.getSourceType().getShape().begin(),
                  packOp.getSourceType().getShape().end());
  destShape.assign(packOp.getDestType().getShape().begin(),
                   packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(packOp.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
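// In other words: for every untiled outer dimension, if exactly one side has
// a static extent, that extent is copied to the other side, e.g. a source
// tensor<?x8xf32> whose dest outer dims are known to be [4, ...] is refined
// to tensor<4x8xf32>. The caller (PackOp::canonicalize below) materializes
// the refinement with tensor.cast ops.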
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold a pack(unpack(x)) pair back to x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(packOp, unPackOp.getSource());
    return success();
  }

  // Drop the optional padding value when it is not needed.
  // ...
    packOp.getPaddingValueMutable().clear();
  // ...

  // Insert tensor.cast ops when static shape inference refines an operand.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(srcShape);
      source =
          rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(destShape);
      dest =
          rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
    }
    rewriter.modifyOpInPlace(packOp, [&] {
      packOp.getSourceMutable().assign(source);
      packOp.getDestMutable().assign(dest);
      packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
    });
    // Cast back to the original result type if users expect it.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
      rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
    }
    return success();
  }

  return failure();
}
// isLikePadUnPad: a pack (unpack) degenerates to a pad (unpad) when all outer
// (untiled) dims of the packed type are unit dims.
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
                "Function meant for pack/unpack");
  // ...
  int64_t numPackedDims = packOp.getInnerDimsPos().size();
  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
  // ...
  int64_t packedRank = packedTensorType.getRank();
  // ...
  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  return llvm::all_of(
      llvm::seq<int64_t>(0, packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}

bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getDestType(), paddingValue))
    return reshapedSource;
  return {};
}
    // FoldTensorCastPackOp: rebuild the pack with the folded operands and the
    // new tile sizes.
    PackOp newOp = rewriter.create<PackOp>(
        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
    // ...
    // If the result type changed, cast back to the original type.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement =
        (newResult.getType() != oldResult.getType())
            ? rewriter.create<tensor::CastOp>(op->getLoc(),
                                              oldResult.getType(), newResult)
            : newResult;
    rewriter.replaceOp(op, {replacement});
void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}

// Outer (untiled) dims of the unpacked result: the leading dest-rank dims of
// the packed source.
  ShapedType destType = getDestType();
  int64_t destRank = destType.getRank();
  return getSourceType().getShape().take_front(destRank);

// Tiled outer dims: the packed sizes at each inner_dims_pos index.
  auto packedShape = getSourceType().getShape();
  // ...
    res.push_back(packedShape[index]);
  // ...
5133 "number of tile sizes specified must match the specified number of "
5134 "original dimensions to be tiled");
5138 build(builder, state, dest.
getType(), source, dest,
5157 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
5159 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5160 if (srcType.isDynamicDim(i))
5161 mixedSizes.push_back(b.
create<tensor::DimOp>(loc, source, i).
getResult());
5163 mixedSizes.push_back(b.
getIndexAttr(srcType.getDimSize(i)));
5166 applyPermutationToVector<OpFoldResult>(
5170 for (
auto [dimPos, tileSize] : llvm::zip_equal(
innerDimsPos, innerTileSizes))
5171 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5173 auto elemType = srcType.getElementType();
5174 return b.
create<tensor::EmptyOp>(loc, mixedSizes, elemType);
5178 Value transposedSource,
5182 *
this, innerPermutation, outerPermutation);
5183 return b.
create<UnPackOp>(loc, transposedSource, getDest(),
5184 metadata.innerDimsPos, metadata.innerTiles,
5185 metadata.outerDimsPerm);
// inferStaticShape(UnPackOp): propagate static extents between source and
// dest shapes for all untiled outer dimensions; returns true if anything
// changed.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(op.getSourceType().getShape().begin(),
                  op.getSourceType().getShape().end());
  destShape.assign(op.getDestType().getShape().begin(),
                   op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(0, destRank)) {
    if (innerDims.contains(i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    if (ShapedType::isDynamic(srcShape[srcPos]) ==
        ShapedType::isDynamic(destShape[destPos])) {
      continue;
    }
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  // Fold an unpack(pack(x)) pair back to x.
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(unPackOp, packOp.getSource());
    return success();
  }

  // If the dest is produced by a destination-style op, use its init instead.
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    auto destValue = cast<OpResult>(unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(
        unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
    return success();
  }

  // Insert tensor.cast ops when static shape inference refines an operand.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
      source = rewriter.create<tensor::CastOp>(loc, newSrcType,
                                               unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(destShape);
      dest = rewriter.create<tensor::CastOp>(loc, newDestType,
                                             unPackOp.getDest());
    }
    UnPackOp newOp = rewriter.create<UnPackOp>(
        loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
        unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        unPackOp, unPackOp.getResult().getType(), newOp);
    return success();
  }

  return failure();
}
bool UnPackOp::isLikeUnPad() {
  RankedTensorType packedTensorType = getSourceType();
  return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
          getResult().getType()))
    return reshapedSource;
  return {};
}
    // FoldTensorCastUnPackOp: rebuild the unpack with the folded operands and
    // the new tile sizes.
    Value sourceTensor = newOperands[0];
    // Derive the tile sizes from the (possibly more static) new source type.
    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
        rewriter, sourceTensor.getType(), op.getMixedTiles());
    // ...
    UnPackOp newOp = rewriter.create<UnPackOp>(
        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
        newMixedTileSizes, op.getOuterDimsPerm());
    // ...
    // If the result type changed, cast back to the original type.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement =
        (newResult.getType() != oldResult.getType())
            ? rewriter.create<tensor::CastOp>(op->getLoc(),
                                              oldResult.getType(), newResult)
            : newResult;
    rewriter.replaceOp(op, {replacement});
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  // ...
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}