40#include "llvm/ADT/DenseMap.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SetOperations.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/SmallVectorExtras.h"
45#include "llvm/ADT/StringSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/FormatVariadic.h"
48#include "llvm/Support/InterleavedRange.h"
49#include "llvm/Support/LogicalResult.h"
50#include "llvm/Support/MathExtras.h"
51#include "llvm/Support/raw_ostream.h"
61 auto type = cast<ShapedType>(v.
getType());
62 if (!type.isDynamicDim(dim))
67 .Case([&](RankedTensorType t) ->
Value {
68 return tensor::DimOp::create(builder, loc, v, dim);
70 .Case([&](MemRefType t) ->
Value {
71 return memref::DimOp::create(builder, loc, v, dim);
82 .Case([&](RankedTensorType t) ->
Operation * {
83 return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
86 .Case([&](MemRefType type) ->
Operation * {
87 return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
99 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
100 return b.createOrFold<memref::DimOp>(loc, source, dim);
101 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
102 return b.createOrFold<tensor::DimOp>(loc, source, dim);
103 llvm_unreachable(
"Expected MemRefType or TensorType");
108 auto shapedType = llvm::cast<ShapedType>(source.
getType());
109 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
111 return b.getIndexAttr(shapedType.getDimSize(dim));
134 for (
auto containers : {inputTypes, outputTypes}) {
135 for (
auto t : containers) {
147 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
163 std::optional<TypeRange> resultTensorTypes,
170 if (!resultTensorTypes)
171 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
172 llvm::IsaPred<RankedTensorType>);
180 "operandSegmentSizes",
181 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
182 static_cast<int32_t>(outputs.size())}));
192 std::optional<TypeRange> resultTensorTypes,
199 return attr.
getName() ==
"indexing_maps";
202 indexingMapsAttrVal = llvm::map_to_vector(
205 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
208 attributes, regionBuilder);
212 std::optional<TypeRange> resultTensorTypes,
219 return attr.
getName() ==
"indexing_maps";
222 indexingMapsAttrVal = llvm::map_to_vector(
225 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
228 attributes, regionBuilder);
232 std::optional<TypeRange> resultTensorTypes,
239 indexingMapsAttrVal =
241 return AffineMapAttr::get(map);
243 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
245 attributes, regionBuilder);
254 bool addOperandSegmentSizes =
true) {
255 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
284 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
286 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
290 if (addOperandSegmentSizes) {
297 if (
result.propertiesAttr) {
299 attrs.
append(
"operandSegmentSizes",
301 {static_cast<int32_t>(inputsOperands.size()),
302 static_cast<int32_t>(outputsOperands.size())}));
305 result.addAttribute(
"operandSegmentSizes",
307 {static_cast<int32_t>(inputsOperands.size()),
308 static_cast<int32_t>(outputsOperands.size())}));
311 if (!
result.propertiesAttr) {
312 std::optional<RegisteredOperationName> info =
313 result.name.getRegisteredInfo();
315 if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
316 return parser.emitError(attrsLoc)
317 <<
"'" << result.name.getStringRef() <<
"' op ";
328 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
329 if (!outputs.empty())
330 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
341 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
344 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
345 "region expects {0} args, got {1}",
346 numRegionArgs, inputTypes.size() + outputTypes.size()));
352 opBuilder, region, inputTypes, outputTypes, attrs,
371 unsigned numRegionArgs,
388 result.addTypes(outputTensorsTypes);
390 std::unique_ptr<Region> region = std::make_unique<Region>();
392 outputTypes,
result.attributes.getAttrs(),
395 result.addRegion(std::move(region));
402 if (resultTypes.empty())
447class RegionBuilderHelper {
449 RegionBuilderHelper(OpBuilder &builder,
Block &block)
450 : builder(builder), block(block) {}
453 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
455 if (!isFloatingPoint(arg)) {
457 emitError() <<
"unsupported non numeric type";
460 llvm_unreachable(
"unsupported non numeric type");
462 OpBuilder::InsertionGuard g(builder);
463 builder.setInsertionPointToEnd(&block);
466 return math::ExpOp::create(builder, arg.
getLoc(), arg);
468 return math::LogOp::create(builder, arg.
getLoc(), arg);
470 return math::AbsFOp::create(builder, arg.
getLoc(), arg);
472 return math::CeilOp::create(builder, arg.
getLoc(), arg);
474 return math::FloorOp::create(builder, arg.
getLoc(), arg);
476 return arith::NegFOp::create(builder, arg.
getLoc(), arg);
477 case UnaryFn::reciprocal: {
478 Attribute oneAttr = builder.getOneAttr(arg.
getType());
479 auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
480 ::cast<TypedAttr>(oneAttr));
481 return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
484 return math::RoundOp::create(builder, arg.
getLoc(), arg);
486 return math::SqrtOp::create(builder, arg.
getLoc(), arg);
488 return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
489 case UnaryFn::square:
490 return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
492 return math::TanhOp::create(builder, arg.
getLoc(), arg);
494 return math::ErfOp::create(builder, arg.
getLoc(), arg);
497 emitError() <<
"unsupported unary function";
500 llvm_unreachable(
"unsupported unary function");
507 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
509 bool allComplex = isComplex(arg0) && isComplex(arg1);
510 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
511 bool allInteger = isInteger(arg0) && isInteger(arg1);
514 if (!allComplex && !allFloatingPoint && !allInteger) {
517 <<
"Cannot build binary Linalg operation: expects allComplex, "
518 "allFloatingPoint, or allInteger, got "
522 llvm_unreachable(
"unsupported non numeric type");
524 OpBuilder::InsertionGuard g(builder);
525 builder.setInsertionPointToEnd(&block);
529 return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
530 if (allFloatingPoint)
531 return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
533 return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
534 return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
537 return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
538 if (allFloatingPoint)
539 return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
542 emitError() <<
"unsupported operation: sub with bools";
545 llvm_unreachable(
"unsupported operation: sub with bools");
547 return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
550 return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
551 if (allFloatingPoint)
552 return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
554 return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
555 return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
558 return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
559 if (allFloatingPoint)
560 return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
563 emitError() <<
"unsupported operation: div with bools";
566 llvm_unreachable(
"unsupported operation: div with bools");
568 return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
569 case BinaryFn::div_unsigned:
570 if (!allInteger || allBool) {
572 emitError() <<
"unsupported operation: unsigned div not on uint";
575 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
577 return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
578 case BinaryFn::max_signed:
580 if (allFloatingPoint)
581 return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
582 return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
583 case BinaryFn::min_signed:
585 if (allFloatingPoint)
586 return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
587 return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
588 case BinaryFn::max_unsigned:
590 if (!allInteger || allBool) {
592 emitError() <<
"unsupported operation: unsigned max not on uint";
595 llvm_unreachable(
"unsupported operation: unsigned max not on uint");
597 return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
598 case BinaryFn::min_unsigned:
600 if (!allInteger || allBool) {
602 emitError() <<
"unsupported operation: unsigned min not on uint";
605 llvm_unreachable(
"unsupported operation: unsigned min not on uint");
607 return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
609 assert(allFloatingPoint);
610 return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
613 emitError() <<
"unsupported binary function";
616 llvm_unreachable(
"unsupported binary function");
620 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
622 OpBuilder::InsertionGuard g(builder);
623 builder.setInsertionPointToEnd(&block);
625 case TernaryFn::select:
626 return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
629 emitError() <<
"unsupported ternary function";
632 llvm_unreachable(
"unsupported ternary function");
636 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
639 case TypeFn::cast_signed:
640 return cast(toType, operand,
false);
641 case TypeFn::cast_unsigned:
642 return cast(toType, operand,
true);
645 emitError() <<
"unsupported type conversion function";
648 llvm_unreachable(
"unsupported type conversion function");
652 OpBuilder::InsertionGuard g(builder);
653 builder.setInsertionPointToEnd(&block);
654 Location loc = builder.getUnknownLoc();
655 YieldOp::create(builder, loc, values);
658 Value constant(
const std::string &value) {
659 OpBuilder::InsertionGuard g(builder);
660 builder.setInsertionPointToEnd(&block);
661 Location loc = builder.getUnknownLoc();
662 Attribute valueAttr =
parseAttribute(value, builder.getContext());
663 return arith::ConstantOp::create(builder, loc,
664 ::cast<TypedAttr>(valueAttr));
667 Value index(int64_t dim) {
668 OpBuilder::InsertionGuard g(builder);
669 builder.setInsertionPointToEnd(&block);
670 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
673 Type getIntegerType(
unsigned width) {
674 return IntegerType::get(builder.getContext(), width);
677 Type getFloat32Type() {
return Float32Type::get(builder.getContext()); }
678 Type getFloat64Type() {
return Float64Type::get(builder.getContext()); }
685 Value cast(Type toType, Value operand,
bool isUnsignedCast) {
686 OpBuilder::InsertionGuard g(builder);
687 builder.setInsertionPointToEnd(&block);
688 auto loc = operand.
getLoc();
689 if (isa<UnknownLoc>(loc)) {
699 bool isComplex(Value value) {
700 return llvm::isa<ComplexType>(value.
getType());
702 bool isFloatingPoint(Value value) {
703 return llvm::isa<FloatType>(value.
getType());
705 bool isInteger(Value value) {
706 return llvm::isa<IntegerType>(value.
getType());
722 using OpRewritePattern<CopyOp>::OpRewritePattern;
723 LogicalResult matchAndRewrite(CopyOp copyOp,
724 PatternRewriter &rewriter)
const override {
725 if (copyOp.getInputs() != copyOp.getOutputs())
727 if (copyOp.hasPureBufferSemantics())
730 rewriter.
replaceOp(copyOp, copyOp.getInputs());
740 results.
add<EraseSelfCopy>(context);
753template <
typename TensorReshapeOp>
754struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
755 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
756 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
757 PatternRewriter &rewriter)
const override {
758 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
762 Location loc = oldFill.getLoc();
763 TensorReshapeOp newInit;
764 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
766 newInit = TensorReshapeOp::create(
767 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
768 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
769 reshapeOp.getStaticOutputShape());
771 newInit = TensorReshapeOp::create(
772 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
773 reshapeOp.getReassociation());
783struct FoldFillWithPad final :
public OpRewritePattern<tensor::PadOp> {
786 LogicalResult matchAndRewrite(tensor::PadOp padOp,
787 PatternRewriter &rewriter)
const override {
788 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
794 Value padValue = padOp.getConstantPaddingValue();
795 if (!padValue || fillOp.value() != padValue)
801 padOp,
"failed to reify tensor.pad op result shape");
804 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
805 padOp.getResultType().getElementType());
807 FillOp::create(rewriter, fillOp.getLoc(),
ValueRange{padValue},
810 if (
replacement.getType() != padOp.getResultType()) {
811 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
822struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
825 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
826 PatternRewriter &rewriter)
const override {
827 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
831 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
836 Value firstDest = insertOp.getDest();
837 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
838 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
843 bool disjoint =
false;
844 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
847 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
848 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
849 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
853 int64_t prevStart = prevOp.getStaticOffset(i);
854 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
855 prevOp.getStaticStride(i);
856 int64_t nextStart = insertOp.getStaticOffset(i);
857 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
858 insertOp.getStaticStride(i);
859 if (prevEnd < nextStart || nextEnd < prevStart) {
867 firstDest = prevOp.getDest();
878 Value padValue = srcPadOp.getConstantPaddingValue();
879 if (!padValue || dstFillOp.value() != padValue)
882 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
883 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
885 Location loc = insertOp.getLoc();
888 AffineExpr sym0, sym1;
894 SmallVector<OpFoldResult, 4> newOffsets;
895 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
896 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
897 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
900 RankedTensorType srcPadType = srcPadOp.getSourceType();
901 SmallVector<OpFoldResult, 4> newSizes;
902 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
903 if (srcPadType.isDynamicDim(i)) {
905 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
908 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
913 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
914 newSizes, insertOp.getMixedStrides());
920struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
922 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
924 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
925 PatternRewriter &rewriter)
const override {
928 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
933 Value extractedScalar = fillOp.getInputs()[0];
936 rewriter.
replaceOp(extractOp, extractedScalar);
944static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
945 linalg::PackOp packOp) {
946 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
950 if (
auto paddingValue = packOp.getPaddingValue())
954 Value packOpDest = packOp.getDest();
958 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
963struct FoldFillWithPack :
public OpRewritePattern<linalg::PackOp> {
965 FoldFillWithPack(MLIRContext *context)
966 : OpRewritePattern<linalg::PackOp>(context) {}
968 LogicalResult matchAndRewrite(linalg::PackOp packOp,
969 PatternRewriter &rewriter)
const override {
970 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
973 rewriter.
replaceOp(packOp, fillOp.value().result());
979struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
980 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
982 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
983 PatternRewriter &rewriter)
const override {
984 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
987 copyOp.getOutputs());
990 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
992 fillOp.getOutputs());
1000struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1001 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1003 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1004 PatternRewriter &rewriter)
const override {
1005 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1007 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1008 transposeOp.getDpsInitOperand(0)->get());
1017struct FoldConcatsOfFill :
public OpRewritePattern<tensor::ConcatOp> {
1020 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1021 PatternRewriter &rewriter)
const override {
1022 auto concatOperands = concatOp.getInputs();
1023 if (concatOperands.empty()) {
1027 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1032 OpFoldResult firstFillVal =
1035 SmallVector<Value> allOuts;
1036 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1038 auto isDefinedByCompatibleFillOp = [&](Value v) ->
bool {
1039 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1044 OpFoldResult fillVal =
1046 if (fillVal != firstFillVal)
1049 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1052 if (!llvm::all_of(concatOperands.drop_front(),
1053 isDefinedByCompatibleFillOp)) {
1055 concatOp,
"not all operands are defined by a compatible fill op");
1058 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1059 concatOp.getDim(), allOuts);
1061 concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
1068void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1069 MLIRContext *context) {
1070 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1071 FoldFillWithPack, FoldFillWithPad,
1072 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1073 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1074 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1087 for (
ValueRange container : {inputs, outputs}) {
1088 for (
Value v : container) {
1089 Type t = v.getType();
1090 blockArgTypes.push_back(
1092 blockArgLocs.push_back(v.getLoc());
1098 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1102void GenericOp::getAsmBlockArgumentNames(Region ®ion,
1104 for (Value v : getRegionInputArgs())
1106 for (Value v : getRegionOutputArgs())
1107 setNameFn(v,
"out");
1110void GenericOp::build(
1111 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1113 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1115 ArrayRef<NamedAttribute> attributes) {
1116 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1117 iteratorTypes, doc, libraryCall);
1118 result.addAttributes(attributes);
1121 inputs, outputs, bodyBuild);
1124void GenericOp::build(
1125 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1127 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1128 StringRef libraryCall,
1130 ArrayRef<NamedAttribute> attributes) {
1131 build(builder,
result, resultTensorTypes, inputs, outputs,
1135 [&](utils::IteratorType iter) -> mlir::Attribute {
1136 return IteratorTypeAttr::get(builder.getContext(), iter);
1139 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1140 bodyBuild, attributes);
1143void GenericOp::build(
1145 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1146 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1147 StringRef libraryCall,
1149 ArrayRef<NamedAttribute> attributes) {
1151 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1154void GenericOp::build(
1156 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1157 ArrayRef<utils::IteratorType> iteratorTypes,
1159 ArrayRef<NamedAttribute> attributes) {
1160 build(builder,
result, inputs, outputs, indexingMaps, iteratorTypes,
1162 "", bodyBuild, attributes);
1165void GenericOp::build(
1166 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1168 ArrayRef<utils::IteratorType> iteratorTypes,
1170 ArrayRef<NamedAttribute> attributes) {
1171 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1174 "", bodyBuild, attributes);
1177void GenericOp::print(OpAsmPrinter &p) {
1181 auto genericAttrNames = linalgTraitAttrNames();
1183 llvm::StringSet<> genericAttrNamesSet;
1184 genericAttrNamesSet.insert_range(genericAttrNames);
1185 SmallVector<NamedAttribute, 8> genericAttrs;
1186 for (
auto attr : (*this)->getAttrs()) {
1187 if (attr.getName() == getIteratorTypesAttrName()) {
1188 auto iteratorTypes =
1189 llvm::cast<ArrayAttr>(attr.getValue())
1190 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1195 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1196 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1197 return StringAttr::get(
getContext(), stringifyIteratorType(t));
1200 genericAttrs.emplace_back(
1201 getIteratorTypesAttrName(),
1202 ArrayAttr::get(
getContext(), iteratorTypeNames));
1203 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1204 genericAttrs.push_back(attr);
1207 if (!genericAttrs.empty()) {
1208 auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
1209 p << genericDictAttr;
1215 genericAttrNames.push_back(
"operandSegmentSizes");
1216 genericAttrNamesSet.insert(genericAttrNames.back());
1218 bool hasExtraAttrs =
false;
1219 for (NamedAttribute n : (*this)->getAttrs()) {
1220 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1223 if (hasExtraAttrs) {
1230 if (!getRegion().empty()) {
1239ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
1240 DictionaryAttr dictAttr;
1248 result.attributes.assign(dictAttr.getValue().begin(),
1249 dictAttr.getValue().end());
1255 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1256 result.attributes.get(getIteratorTypesAttrName(
result.name)));
1257 if (!iteratorTypes) {
1258 return parser.
emitError(attributeLocation)
1259 <<
"expected " << getIteratorTypesAttrName(
result.name)
1260 <<
" array attribute";
1263 SmallVector<Attribute> iteratorTypeAttrs;
1265 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1266 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1267 if (!maybeIteratorType.has_value())
1269 <<
"unexpected iterator_type (" << s <<
")";
1271 iteratorTypeAttrs.push_back(
1272 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
1274 result.attributes.set(getIteratorTypesAttrName(
result.name),
1278 SmallVector<Type, 1> inputTypes, outputTypes;
1288 std::unique_ptr<Region> region = std::make_unique<Region>();
1291 result.addRegion(std::move(region));
1297 SmallVector<Type, 1> outputTensorsTypes;
1300 result.addTypes(outputTensorsTypes);
1308 LinalgOp linalgOp) {
1309 for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1310 if (!llvm::isa<MemRefType>(operand.
getType()))
1312 effects.emplace_back(
1317 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1318 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1320 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1331void GenericOp::getEffects(
1332 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1341 if (!linalgOp.hasPureTensorSemantics())
1359template <
typename OpTy>
1360struct EraseIdentityLinalgOp :
public OpRewritePattern<OpTy> {
1361 using OpRewritePattern<OpTy>::OpRewritePattern;
1363 LogicalResult matchAndRewrite(OpTy linalgOp,
1364 PatternRewriter &rewriter)
const override {
1366 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1371 Block &body = linalgOp->getRegion(0).front();
1372 if (!llvm::hasSingleElement(body))
1374 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1379 if (linalgOp.hasPureBufferSemantics()) {
1380 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1381 linalgOp.getDpsInputOperand(0)->get() !=
1382 linalgOp.getDpsInitOperand(0)->get()) {
1384 linalgOp,
"expected single input and output to be the same value");
1387 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1388 if (!yieldArg || yieldArg.getOwner() != &body) {
1390 "cannot fold fill-like op");
1397 if (!linalgOp.hasPureTensorSemantics()) {
1399 linalgOp,
"mixed semantics is not supported yet");
1404 SmallVector<Value> returnedArgs;
1405 for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1406 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1407 if (!yieldArg || yieldArg.getOwner() != &body)
1409 unsigned argumentNumber = yieldArg.getArgNumber();
1410 Value returnedArg = linalgOp->getOperand(argumentNumber);
1411 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1414 Type returnType = returnedArg.
getType();
1415 if (returnType != resultType) {
1420 returnedArg = sparse_tensor::ConvertOp::create(
1421 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1423 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1426 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1427 resultType, returnedArg);
1430 returnedArgs.push_back(returnedArg);
1433 if (returnedArgs.size() != linalgOp->getNumResults())
1435 rewriter.
replaceOp(linalgOp, returnedArgs);
1442void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1443 MLIRContext *context) {
1444 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1447LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1466 for (
Type outputType : outputTypes) {
1467 if (llvm::isa<RankedTensorType>(outputType))
1468 result.addTypes(outputType);
1472 if (parseAttrsFn && failed(parseAttrsFn(parser,
result.attributes)))
1481void MapOp::getAsmBlockArgumentNames(Region ®ion,
1483 for (Value v : getRegionInputArgs())
1485 for (Value v : getRegionOutputArgs())
1486 setNameFn(v,
"init");
1489void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1490 if (!getResults().empty())
1491 setNameFn(getResults().front(),
"mapped");
1497 ArrayRef<NamedAttribute> attributes) {
1499 result.addAttributes(attributes);
1502 Type initType = init.
getType();
1503 if (llvm::isa<RankedTensorType>(initType))
1504 result.addTypes(initType);
1508 inputs, {init}, bodyBuild);
1515 bool initFirst =
false,
bool mapInit =
true) {
1519 b.setInsertionPointToStart(&block);
1520 for (
auto &operand : operands) {
1522 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1530 payloadOpOperands.push_back(block.
getArguments().back());
1531 for (
const auto &arg : block.
getArguments().drop_back())
1532 payloadOpOperands.push_back(arg);
1541 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1547ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
1548 std::optional<OperationName> payloadOpName;
1549 NamedAttrList payloadOpAttrs;
1552 if (
failed(operationName))
1556 payloadOpName = operationName.value();
1564 if (payloadOpName.has_value()) {
1565 if (!
result.operands.empty())
1567 payloadOpAttrs, ArrayRef(
result.operands),
false,
1572 SmallVector<OpAsmParser::Argument> regionArgs;
1577 Region *body =
result.addRegion();
1585 bool mapInit =
true) {
1587 if (initFirst && !mapInit)
1611 for (
const auto &[operand, bbArg] :
1613 if (bbArg != operand)
1617 for (
const auto &[operand, bbArg] :
1620 if (bbArg != operand)
1627 return yieldOp.getNumOperands() == 1 &&
1628 yieldOp.getOperand(0).getDefiningOp() &&
1629 yieldOp.getOperand(0).getDefiningOp() == &payload;
1634 std::string attrToElide;
1636 for (
const auto &attr : payloadOp->
getAttrs()) {
1638 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1639 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1640 attrToElide = attr.getName().str();
1641 elidedAttrs.push_back(attrToElide);
1649void MapOp::print(OpAsmPrinter &p) {
1650 Block *mapper = getBody();
1660 if (!useShortForm) {
1666 [&](
auto arg) { p.printRegionArgument(arg); });
1674LogicalResult MapOp::verify() {
1675 auto *bodyBlock = getBody();
1676 auto blockArgs = bodyBlock->getArguments();
1680 if (getInputs().size() + 1 != blockArgs.size())
1681 return emitOpError() <<
"expects number of operands to match the arity of "
1683 << getInputs().size() + 1 <<
" and "
1684 << blockArgs.size();
1687 for (
const auto &[bbArgType, inputArg] :
1688 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1689 auto inputElemType =
1690 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1691 if (bbArgType != inputElemType) {
1692 return emitOpError() <<
"expected element type of input " << inputElemType
1693 <<
" to match bbArg type " << bbArgType;
1698 auto outputShape = getInit().getType().getShape();
1699 for (Type inputArgType :
TypeRange{getInputs()}) {
1700 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1701 if (inputElemShape != outputShape) {
1702 return emitOpError() <<
"expected shape of input (" << inputElemShape
1703 <<
") to match shape of output (" << outputShape
1711SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1712 int64_t rank = getInit().getType().getRank();
1713 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1718 int64_t rank = getInit().getType().getRank();
1719 int64_t numIndexingMaps = getOperands().size();
1724void MapOp::getEffects(
1725 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1738void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1740 for (Value v : getRegionInputArgs())
1742 for (Value v : getRegionOutputArgs())
1743 setNameFn(v,
"init");
1746void ReduceOp::getAsmResultNames(
1748 if (!getResults().empty())
1749 setNameFn(getResults().front(),
"reduced");
1752void ReduceOp::build(
1754 ValueRange inits, ArrayRef<int64_t> dimensions,
1756 ArrayRef<NamedAttribute> attributes) {
1758 result.addAttributes(attributes);
1761 for (Value init : inits) {
1762 Type initType = init.
getType();
1763 if (llvm::isa<RankedTensorType>(initType))
1764 result.addTypes(initType);
1769 inputs, inits, bodyBuild);
1772SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1774 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1775 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1776 utils::IteratorType::parallel);
1777 for (int64_t reductionDim : getDimensions())
1778 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1779 return iteratorTypes;
1784 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1785 SmallVector<AffineMap> affineMaps(
1788 AffineMap resultMap =
1791 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1792 affineMaps.push_back(resultMap);
1793 return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
1796void ReduceOp::getEffects(
1797 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1808 StringRef attributeName) {
1816ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
1817 std::optional<OperationName> payloadOpName;
1818 NamedAttrList payloadOpAttrs;
1821 if (
failed(operationName))
1825 payloadOpName = operationName.value();
1831 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1836 if (payloadOpName.has_value()) {
1838 ArrayRef(
result.operands),
true);
1840 SmallVector<OpAsmParser::Argument> regionArgs;
1846 Region *body =
result.addRegion();
1856 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1859void ReduceOp::print(OpAsmPrinter &p) {
1860 Block *mapper = getBody();
1869 if (!useShortForm) {
1875 [&](
auto arg) { p.printRegionArgument(arg); });
1883LogicalResult ReduceOp::verify() {
1884 ArrayRef<int64_t> dimensionsRef = getDimensions();
1886 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1889 return emitOpError() <<
"expects all inputs to have the same shapes. "
1890 "Shape at input-index "
1892 <<
" is not equal to the shape at input-index 0.";
1895 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1898 return emitOpError() <<
"expects all outputs to have the same shapes. "
1899 "Shape at output-index "
1901 <<
" is not equal to the shape at output-index 0.";
1904 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1905 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1908 for (int64_t dimension : dimensionsRef) {
1909 if (dimension < 0 || dimension >= inputType.getRank()) {
1911 <<
"dimensions for reduction should be in the range [0, "
1912 << inputType.getRank() - 1 <<
"].";
1914 dimensionsToReduce.insert(dimension);
1917 auto inputDims = inputType.getShape();
1918 auto initDims = initType.getShape();
1921 SmallVector<int64_t> reducedInputDims;
1922 for (
const auto &en : llvm::enumerate(inputDims)) {
1923 if (!dimensionsToReduce.count(en.index()))
1924 reducedInputDims.push_back(en.value());
1927 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1928 return emitOpError() <<
"number of dimensions after reduction "
1929 << reducedInputDims.size()
1930 <<
" doesn't match the init rank "
1931 << initType.getRank();
1934 if (reducedInputDims != initDims)
1935 return emitOpError() <<
"init dimensions [" << initDims
1936 <<
"] doesn't match input dimensions after reduction ["
1937 << reducedInputDims <<
"]";
1939 Block *block = getBody();
1942 <<
"mismatching number of operands and block arguments";
1945 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1946 Type inputElementType =
1947 llvm::cast<ShapedType>(input.getType()).getElementType();
1948 if (inputElementType != bbArg.getType())
1950 <<
"input element type " << inputElementType
1951 <<
" does not match corresponding block argument type "
1956 for (
auto [output, bbArg] : llvm::zip(
1957 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1958 auto outputElementType =
1959 llvm::cast<ShapedType>(output.getType()).getElementType();
1960 if (outputElementType != bbArg.getType())
1962 <<
"output element type " << outputElementType
1963 <<
" does not match corresponding block argument type "
1979 linalg::YieldOp::create(
b, loc, args[0]);
1983void TransposeOp::build(::mlir::OpBuilder &builder,
1984 ::mlir::OperationState &
result, Value input, Value init,
1986 ArrayRef<NamedAttribute> attributes) {
1987 result.addOperands(input);
1988 result.addOperands(init);
1989 result.addAttribute(getPermutationAttrName(
result.name), permutation);
1990 result.addAttributes(attributes);
1993 Type initType = init.
getType();
1994 if (llvm::isa<RankedTensorType>(initType))
1995 result.addTypes(initType);
2001void TransposeOp::build(::mlir::OpBuilder &builder,
2002 ::mlir::OperationState &
result, Value input, Value init,
2003 ArrayRef<int64_t> permutation,
2004 ArrayRef<NamedAttribute> attributes) {
2009ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
2011 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2023void TransposeOp::getAsmResultNames(
2025 if (!getResults().empty())
2026 setNameFn(getResults().front(),
"transposed");
2029void TransposeOp::print(OpAsmPrinter &p) {
2035LogicalResult TransposeOp::verify() {
2036 ArrayRef<int64_t> permutationRef = getPermutation();
2041 auto inputType = getInput().getType();
2042 auto initType = getInit().getType();
2044 int64_t rank = inputType.getRank();
2050 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
2051 return emitOpError() <<
"size of permutation " << permutationRef.size()
2052 <<
" does not match the argument rank " << rank;
2054 auto inputDims = inputType.getShape();
2055 auto initDims = initType.getShape();
2057 for (int64_t i = 0; i < rank; ++i) {
2058 int64_t inputDim = inputDims[permutationRef[i]];
2059 int64_t initDim = initDims[i];
2061 if (inputDim != initDim) {
2062 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
2063 <<
" doesn't match dim(input, permutation[" << i
2064 <<
"]) = " << inputDim;
2071SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2072 int64_t rank = getInit().getType().getRank();
2073 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2076ArrayAttr TransposeOp::getIndexingMaps() {
2078 int64_t rank = getInit().getType().getRank();
2081 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
2085void TransposeOp::getEffects(
2086 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2095LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2096 SmallVectorImpl<OpFoldResult> &
result) {
2098 if (!isa<TensorType>(getInput().
getType()))
2102 if (getPermutation().empty()) {
2103 result.push_back(getInput());
2108 result.push_back(getInput());
2121 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2122 if (!defTransposeOp)
2127 foldedPerms.reserve(perms.size());
2129 foldedPerms.push_back(defPerms[perm]);
2132 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2146 Value input = transposeOp.getInput();
2147 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2158 unsigned dimensionSize = dimensions.size();
2159 for (
unsigned i = 0; i < dimensionSize; ++i)
2160 resultDimensions.push_back(invertPerm[dimensions[i]]);
2163 Value broadcastInput = broadcastOp.getInput();
2164 Location loc = transposeOp.getLoc();
2167 auto broadcastInputTy =
2168 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2169 unsigned inputRank = broadcastInputTy.getRank();
2170 for (
unsigned i = 0; i < inputRank; ++i) {
2171 if (broadcastInputTy.isDynamicDim(i)) {
2172 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2175 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2176 broadcastInputTy.getDimSize(i)));
2181 Value transposeInit = tensor::EmptyOp::create(
2182 rewriter, transposeOp.getLoc(), transposeResultShapes,
2183 broadcastInputTy.getElementType());
2186 Value transposeResult =
2187 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2188 transposeInit, resultPerms)
2191 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2196void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2197 MLIRContext *context) {
2198 results.
add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2205void BroadcastOp::build(::mlir::OpBuilder &builder,
2206 ::mlir::OperationState &
result, Value input, Value init,
2208 ArrayRef<NamedAttribute> attributes) {
2209 result.addOperands(input);
2210 result.addOperands(init);
2211 result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
2212 result.addAttributes(attributes);
2215 Type initType = init.
getType();
2216 if (llvm::isa<RankedTensorType>(initType))
2217 result.addTypes(initType);
2223void BroadcastOp::build(::mlir::OpBuilder &builder,
2224 ::mlir::OperationState &
result, Value input, Value init,
2225 ArrayRef<int64_t> dimensions,
2226 ArrayRef<NamedAttribute> attributes) {
2231ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
2233 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2245void BroadcastOp::getAsmResultNames(
2247 if (!getResults().empty())
2248 setNameFn(getResults().front(),
"broadcasted");
2251void BroadcastOp::print(OpAsmPrinter &p) {
2257LogicalResult BroadcastOp::verify() {
2258 ArrayRef<int64_t> dimensionsRef = getDimensions();
2260 auto inputType = getInput().getType();
2261 auto initType = getInit().getType();
2263 int64_t inputRank = inputType.getRank();
2264 int64_t initRank = initType.getRank();
2266 auto inputShape = inputType.getShape();
2267 auto initShape = initType.getShape();
2269 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2270 return emitOpError() <<
"input rank plus added dimensions does not "
2271 "match init rank. input rank: "
2273 <<
", dimensions size: " << dimensionsRef.size()
2274 <<
", init rank: " << initRank;
2276 for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2277 if (dim < 0 || dim >= initRank)
2279 <<
" is out of range. expected range: [0, "
2280 << initRank - 1 <<
"], got: " << dim;
2284 SmallVector<int64_t> dimMap;
2285 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2286 if (!llvm::is_contained(dimensionsRef, dim))
2287 dimMap.push_back(dim);
2290 for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2293 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2294 return emitOpError() <<
"input dim " << inputDimIdx
2295 <<
" should match init dim " << initDimIdx
2296 <<
". input: " << inputShape[inputDimIdx]
2297 <<
", init: " << initShape[initDimIdx];
2303SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2304 int64_t rank = getInit().getType().getRank();
2305 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2308ArrayAttr BroadcastOp::getIndexingMaps() {
2310 int64_t rank = getInit().getType().getRank();
2316void BroadcastOp::getEffects(
2317 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2332 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2333 if (!defBroadcastOp)
2338 Value init = broadcastOp.getInit();
2342 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2343 if (!llvm::is_contained(dimensions, dim))
2344 dimMap.push_back(dim);
2346 for (
auto dim : defDimensions)
2347 foldedDims.push_back(dimMap[dim]);
2349 llvm::sort(foldedDims);
2351 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2356void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2357 MLIRContext *context) {
2358 results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2365void linalg::YieldOp::print(OpAsmPrinter &p) {
2366 if (getNumOperands() > 0)
2367 p <<
' ' << getOperands();
2369 if (getNumOperands() > 0)
2370 p <<
" : " << getOperandTypes();
2373ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
2374 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2375 SmallVector<Type, 2> types;
2385static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2386 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2387 return op.emitOpError(
"expected number of yield values (")
2388 << op.getNumOperands()
2389 <<
") to match the number of inits / outs operands of the enclosing "
2390 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2392 for (
OpOperand &opOperand : op->getOpOperands()) {
2394 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2396 if (isa<MemRefType, RankedTensorType>(elementType))
2398 if (opOperand.get().getType() != elementType)
2399 return op.emitOpError(
"type of yield operand ")
2400 << (opOperand.getOperandNumber() + 1) <<
" ("
2401 << opOperand.get().getType() <<
") doesn't match "
2402 <<
"the element type of the enclosing linalg.generic op ("
2403 << elementType <<
")";
2408LogicalResult linalg::YieldOp::verify() {
2409 auto *parentOp = (*this)->getParentOp();
2410 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2411 return emitOpError(
"expected single non-empty parent region");
2413 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2416 return emitOpError(
"expected parent op with LinalgOp interface");
2423LogicalResult IndexOp::verify() {
2424 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2426 return emitOpError(
"expected parent op with LinalgOp interface");
2427 if (linalgOp.getNumLoops() <= getDim())
2429 << getDim() <<
") to be lower than the number of loops ("
2430 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2434OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2435 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2440 return OpFoldResult{};
2443 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2444 uint64_t dim = getDim();
2445 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
2446 if (loopBounds[dim] == 1)
2447 return IntegerAttr::get(IndexType::get(
getContext()), 0);
2449 return OpFoldResult{};
2454#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2456#define GET_OP_CLASSES
2457#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2459#define GET_OP_CLASSES
2460#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2461#define GET_OP_CLASSES
2462#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2479 for (
unsigned i = 0; i < num; ++i)
2486 auto rangeA = llvm::make_range(a.begin(), a.end());
2487 auto rangeB = llvm::make_range(
b.begin(),
b.end());
2488 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2489 return llvm::to_vector<4>(concatRanges);
2493 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2495 for (
auto size :
memref.getShape())
2502 if (
auto as =
memref.getMemorySpace()) {
2503 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2504 ss <<
"as" << attr.getInt();
2510 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2513 vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2526 assert(isa<LinalgOp>(op));
2528 std::string fun =
"";
2530 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2531 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2532 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2533 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2537 llvm::replace(name,
'.',
'_');
2538 llvm::raw_string_ostream ss(name);
2542 return std::string();
2557 LogicalResult matchAndRewrite(LinalgOp op,
2559 for (
OpOperand &opOperand : op->getOpOperands()) {
2563 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2566 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2577struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2578 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2580 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2581 PatternRewriter &rewriter)
const override {
2585 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2592 if (castOp->getBlock() != linalgOp->getBlock())
2595 OpBuilder::InsertionGuard guard(rewriter);
2598 Location loc = linalgOp.getLoc();
2599 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2602 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2608 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2610 tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
2611 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2612 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2613 linalgOp.getDpsInits().end());
2614 outputOperands[resultNumber] = newOperand;
2615 newOperands.append(outputOperands.begin(), outputOperands.end());
2617 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2618 linalgOp->result_type_end());
2619 resultTypes[resultNumber] = resultType;
2620 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2623 Value castBack = tensor::CastOp::create(
2627 results[resultNumber] = castBack;
2636static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2637 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2638 for (OpOperand &opOperand : operands) {
2639 if (linalgOp.isScalar(&opOperand))
2641 Value src = opOperand.get();
2642 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2643 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2649 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2651 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2652 Value castSource = castOp.getSource();
2653 auto castSourceType =
2654 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2655 if (castSourceType && castSourceType.hasStaticShape())
2656 sourceShape = castSourceType.getShape();
2662 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2663 if (sourceType.isDynamicDim(i))
2665 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2666 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2676static void createNewOperandWithStaticSizes(
2677 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2678 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2679 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2680 bool &changeNeeded) {
2681 Value src = opOperand->
get();
2682 newOperands.push_back(src);
2683 if (linalgOp.isScalar(opOperand))
2685 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2686 Type resultType = sourceType;
2687 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2688 resultTypes.push_back(resultType);
2691 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2692 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2693 SmallVector<int64_t> newShape;
2696 bool newOperandNeeded =
false;
2697 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2698 int64_t dimShape = sourceShape[i];
2699 AffineExpr dimExpr = sourceMap.
getResult(i);
2700 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2701 newShape.push_back(dimShape);
2707 newShape.push_back(affineExprToSize[dimExpr]);
2708 newOperandNeeded =
true;
2710 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2711 sourceType.getEncoding());
2712 if (newOperandNeeded) {
2713 changeNeeded =
true;
2716 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2718 newOperands[index] = newOperand;
2720 if (linalgOp.isDpsInit(opOperand))
2721 resultTypes.push_back(resultType);
2727struct InferStaticShapeOfOperands :
public OpInterfaceRewritePattern<LinalgOp> {
2728 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2730 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2731 PatternRewriter &rewriter)
const override {
2732 if (!linalgOp.hasPureTensorSemantics())
2736 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2737 return !map.isProjectedPermutation();
2742 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2743 Location loc = linalgOp.getLoc();
2747 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2749 SmallVector<Value> newOperands;
2750 SmallVector<Type> resultTypes;
2754 bool changeNeeded =
false;
2755 newOperands.reserve(linalgOp->getNumOperands());
2756 resultTypes.reserve(linalgOp.getNumDpsInits());
2759 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2760 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2761 affineExprToSize, linalgOp, newOperands,
2762 resultTypes, changeNeeded);
2771 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2772 SmallVector<Value> replacements;
2774 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2775 Value newResult = std::get<1>(it);
2776 Value oldResult = std::get<0>(it);
2777 Type newType = newResult.
getType();
2778 Type oldType = oldResult.
getType();
2779 replacements.push_back(
2780 (newType != oldType)
2781 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2784 rewriter.
replaceOp(linalgOp, replacements);
2798LogicalResult SoftmaxOp::verify() {
2799 ShapedType inputType = getInputOperandType();
2800 ShapedType outputType = getOutputOperandType();
2802 ArrayRef<int64_t> inputShape = inputType.getShape();
2803 ArrayRef<int64_t> outputShape = outputType.getShape();
2807 int64_t inputRank = getInputOperandRank();
2808 int64_t dimension = getDimension();
2809 if ((dimension < 0) || (dimension >= inputRank))
2810 return emitOpError(
"incorrect dimension specified");
2815SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2816 int64_t operandRank = getInputOperandRank();
2817 SmallVector<Range> loopBounds(operandRank);
2818 Location loc = getLoc();
2821 Value source = getInput();
2822 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2823 loopBounds[dim].offset = zero;
2824 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2825 loopBounds[dim].stride = one;
2830SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2831 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2832 utils::IteratorType::parallel);
2833 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2834 return iteratorTypes;
2837FailureOr<TilingResult>
2838SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2839 ArrayRef<OpFoldResult> offsets,
2840 ArrayRef<OpFoldResult> sizes) {
2841 int64_t rank = getInputOperandRank();
2843 SmallVector<OpFoldResult> strides(rank, oneAttr);
2844 SmallVector<Value> tiledOperands;
2845 Operation *inputSlice =
2846 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2848 return emitOpError(
"failed to compute input slice");
2850 tiledOperands.emplace_back(inputSlice->
getResult(0));
2851 Operation *outputSlice =
2852 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2854 return emitOpError(
"failed to compute output slice");
2856 tiledOperands.emplace_back(outputSlice->
getResult(0));
2858 SmallVector<Type, 4> resultTypes;
2859 if (hasPureTensorSemantics())
2860 resultTypes.push_back(tiledOperands[1].
getType());
2861 Operation *tiledOp =
2862 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2864 return TilingResult{
2867 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2870LogicalResult SoftmaxOp::getResultTilePosition(
2871 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2872 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2873 SmallVector<OpFoldResult> &resultSizes) {
2874 if (resultNumber == 0) {
2875 resultOffsets.assign(offsets.begin(), offsets.end());
2876 resultSizes.assign(sizes.begin(), sizes.end());
2883LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2888SoftmaxOp::reifyResultShapes(OpBuilder &
b,
2890 SmallVector<OpFoldResult> shapes;
2891 Location loc = getOperation()->getLoc();
2892 IRRewriter rewriter(
b);
2893 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2894 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2895 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2896 if (!outputShapedType.isDynamicDim(dim)) {
2898 shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
2905 reifiedReturnShapes.emplace_back(std::move(shapes));
2909void SoftmaxOp::getEffects(
2910 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2912 for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2913 if (!llvm::isa<MemRefType>(operand.
getType()))
2916 &getOperation()->getOpOperand(index), 0,
2921 for (OpOperand &operand : getDpsInitsMutable()) {
2922 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2953static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2955 int64_t dim,
bool allParallel =
false) {
2957 utils::IteratorType::parallel);
2959 iteratorTypes[dim] = utils::IteratorType::reduction;
2963 for (
int i = 0; i < inputRank; i++) {
2970 return std::make_tuple(iteratorTypes, indexingMaps);
2975template <
typename T>
2978 auto inputType = cast<ShapedType>(input.
getType());
2980 int64_t inputRank = inputShape.size();
2981 auto [iteratorTypes, indexingMaps] =
2983 assert(indexingMaps.size() == 2 &&
2984 "We should have two maps: 1 for the input, 1 for the output");
2985 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2987 auto genericOp = linalg::GenericOp::create(
2988 builder, loc, output.
getType(), input, output, indexingMaps,
2990 Value result = T::create(b, loc, args[0], args[1]);
2991 linalg::YieldOp::create(b, loc, result);
2993 return genericOp.getResult(0);
3001 auto inputType = cast<ShapedType>(input.
getType());
3003 int64_t inputRank = inputShape.size();
3005 builder, inputRank, dim,
true);
3006 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
3007 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3009 indexingMaps.push_back(indexingMaps[0]);
3010 auto genericOp = linalg::GenericOp::create(
3012 indexingMaps, iteratorTypes,
3014 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3015 Value result = math::ExpOp::create(b, loc, diff);
3016 linalg::YieldOp::create(b, loc, result);
3018 return genericOp.getResult(0);
3028 auto inputType = cast<ShapedType>(numerator.
getType());
3030 int64_t inputRank = inputShape.size();
3032 builder, inputRank, dim,
true);
3033 assert(indexingMaps.size() == 2 &&
3034 "We should have one map for each input (2)");
3035 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
3037 indexingMaps.push_back(indexingMaps[0]);
3038 auto genericOp = linalg::GenericOp::create(
3040 output, indexingMaps, iteratorTypes,
3042 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3043 linalg::YieldOp::create(b, loc, result);
3045 return genericOp.getResult(0);
3067FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
3068 OpBuilder::InsertionGuard guard(
b);
3069 b.setInsertionPoint(*
this);
3070 Location loc = getLoc();
3071 Value input = getInput();
3072 ShapedType inputType = getInputOperandType();
3073 Type elementType = inputType.getElementType();
3074 int64_t reductionDim = getDimension();
3076 Value output = getOutput();
3077 dims.erase(dims.begin() + reductionDim);
3079 Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
3081 elementType,
b, loc,
3083 Value neutralForMaxFInit =
3084 linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
3096 linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
3102 buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
3103 return SmallVector<Value>{
result};
3110LogicalResult WinogradFilterTransformOp::verify() {
3111 auto filterType = cast<ShapedType>(getFilter().
getType());
3112 ArrayRef<int64_t> filterShape = filterType.getShape();
3113 int64_t filterH = filterShape[getFilterHDim()];
3114 int64_t filterW = filterShape[getFilterWDim()];
3115 WinogradConv2DFmr fmr = getFmr();
3119 if (filterH != r && filterH != 1)
3120 return emitOpError(
"expect filter height either equals to r or 1");
3121 if (filterW != r && filterW != 1)
3122 return emitOpError(
"expect filter width either equals to r or 1");
3123 if (filterH == 1 && filterW == 1)
3124 return emitOpError(
"expect either filter height or width equals to r");
3126 SmallVector<int64_t> expectedOutputShape;
3127 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3128 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3129 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3130 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3132 auto outputType = cast<ShapedType>(getOutput().
getType());
3133 ArrayRef<int64_t> outputShape = outputType.getShape();
3135 return emitOpError(
"the output shape is not expected");
3141WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3142 Location loc = getLoc();
3145 Value filter = getFilter();
3146 int64_t filterRank = getFilterOperandRank();
3147 SmallVector<Range> loopBounds(filterRank);
3148 for (
unsigned dim = 0; dim < filterRank; ++dim) {
3149 loopBounds[dim].offset = zeroAttr;
3150 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3151 loopBounds[dim].stride = oneAttr;
3156SmallVector<utils::IteratorType>
3157WinogradFilterTransformOp::getLoopIteratorTypes() {
3158 int64_t filterRank = getFilterOperandRank();
3159 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3160 utils::IteratorType::parallel);
3161 return iteratorTypes;
3164LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3165 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3166 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3167 SmallVector<OpFoldResult> &resultSizes) {
3169 ShapedType filterType = getFilterOperandType();
3170 ArrayRef<int64_t> filterShape = filterType.getShape();
3171 int64_t filterH = filterShape[getFilterHDim()];
3172 int64_t filterW = filterShape[getFilterWDim()];
3173 WinogradConv2DFmr fmr = getFmr();
3176 int64_t alpha = m + r - 1;
3177 int64_t alphaH = filterH != 1 ? alpha : 1;
3178 int64_t alphaW = filterW != 1 ? alpha : 1;
3182 resultOffsets.append(
3183 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3185 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3196FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3197 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3198 ArrayRef<OpFoldResult> sizes) {
3201 ShapedType filterType = getFilterOperandType();
3202 ArrayRef<int64_t> filterShape = filterType.getShape();
3203 int64_t filterH = filterShape[getFilterHDim()];
3204 int64_t filterW = filterShape[getFilterWDim()];
3207 SmallVector<Value> tiledOperands;
3208 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3210 sliceOffsets.append(
3211 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3212 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3213 sizes[getFilterCDim()]});
3214 int64_t filterRank = getFilterOperandRank();
3215 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3216 Location loc = getLoc();
3217 auto filterSlice = tensor::ExtractSliceOp::create(
3218 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3219 tiledOperands.emplace_back(filterSlice);
3221 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3226 int64_t outputRank = getOutputOperandRank();
3227 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3228 auto outputSlice = tensor::ExtractSliceOp::create(
3229 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3230 tiledOperands.emplace_back(outputSlice);
3232 SmallVector<Type> resultTypes;
3233 resultTypes.push_back(tiledOperands[1].
getType());
3234 Operation *tiledOp =
3235 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3237 return TilingResult{
3240 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3247LogicalResult WinogradInputTransformOp::verify() {
3248 auto inputType = cast<ShapedType>(getInput().
getType());
3249 ArrayRef<int64_t> inputShape = inputType.getShape();
3250 int64_t inputH = inputShape[getInputHDim()];
3251 int64_t inputW = inputShape[getInputWDim()];
3252 WinogradConv2DFmr fmr = getFmr();
3255 int64_t tileSize = m + r - 1;
3257 auto outputType = cast<ShapedType>(getOutput().
getType());
3258 ArrayRef<int64_t> outputShape = outputType.getShape();
3259 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3260 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3262 SmallVector<int64_t> expectedOutputShape(6, inputH);
3263 if (ShapedType::isDynamic(inputH)) {
3264 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3265 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3267 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3268 expectedOutputShape[getOutputTileHDim()] =
3269 leftTransform ? (inputH - (r - 1)) / m : inputH;
3271 if (ShapedType::isDynamic(inputW)) {
3272 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3273 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3275 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3276 expectedOutputShape[getOutputTileWDim()] =
3277 rightTransform ? (inputW - (r - 1)) / m : inputW;
3279 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3280 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3283 return emitOpError(
"the output shape is not expected");
3289WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3290 Location loc = getLoc();
3293 Value output = getOutput();
3294 int64_t outputRank = getOutputOperandRank();
3295 SmallVector<Range> loopBounds(outputRank);
3296 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3297 loopBounds[dim].offset = zeroAttr;
3299 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3300 loopBounds[dim].stride = oneAttr;
3305SmallVector<utils::IteratorType>
3306WinogradInputTransformOp::getLoopIteratorTypes() {
3307 int64_t outputRank = getOutputOperandRank();
3308 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3309 utils::IteratorType::parallel);
3310 return iteratorTypes;
3313LogicalResult WinogradInputTransformOp::getResultTilePosition(
3314 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3315 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3316 SmallVector<OpFoldResult> &resultSizes) {
3318 ShapedType outputType = getOutputOperandType();
3319 ArrayRef<int64_t> outputShape = outputType.getShape();
3320 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3321 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3323 WinogradConv2DFmr fmr = getFmr();
3326 int64_t alpha = m + r - 1;
3327 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3328 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3333 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3334 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3335 offsets[getOutputCDim()]});
3336 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3337 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3338 sizes[getOutputCDim()]});
3349FailureOr<TilingResult>
3350WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3351 ArrayRef<OpFoldResult> offsets,
3352 ArrayRef<OpFoldResult> sizes) {
3354 WinogradConv2DFmr fmr = getFmr();
3358 ShapedType outputType = getOutputOperandType();
3359 ArrayRef<int64_t> outputShape = outputType.getShape();
3360 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3361 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3363 Location loc = getLoc();
3365 auto identityAffineMap =
3367 auto offsetAffineMap =
3370 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3371 offsets[getOutputTileHDim()]);
3373 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3374 offsets[getOutputTileWDim()]);
3378 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3380 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3382 SmallVector<Value> tiledOperands;
3383 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3385 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3386 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3387 sliceOffsets.append(
3388 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3389 OpFoldResult sizeH =
3390 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3391 OpFoldResult sizeW =
3392 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3394 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3395 int64_t inputRank = getInputOperandRank();
3396 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3397 auto inputSlice = tensor::ExtractSliceOp::create(
3398 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3399 tiledOperands.emplace_back(inputSlice);
3401 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3406 int64_t outputRank = getOutputOperandRank();
3407 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3408 auto outputSlice = tensor::ExtractSliceOp::create(
3409 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3410 tiledOperands.emplace_back(outputSlice);
3412 SmallVector<Type> resultTypes;
3413 resultTypes.push_back(tiledOperands[1].
getType());
3414 Operation *tiledOp =
3415 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3417 return TilingResult{
3420 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3427LogicalResult WinogradOutputTransformOp::verify() {
3428 auto valueType = cast<ShapedType>(getValue().
getType());
3429 ArrayRef<int64_t> valueShape = valueType.getShape();
3430 int64_t valueH = valueShape[getValueAlphaHDim()];
3431 int64_t valueW = valueShape[getValueAlphaWDim()];
3432 int64_t valueTileH = valueShape[getValueTileHDim()];
3433 int64_t valueTileW = valueShape[getValueTileWDim()];
3434 WinogradConv2DFmr fmr = getFmr();
3437 bool leftTransform = valueH != 1;
3438 bool rightTransform = valueW != 1;
3440 int64_t outputRank = getOutputOperandRank();
3441 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3442 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3443 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3445 if (valueH != (leftTransform ? m + r - 1 : 1))
3446 return emitOpError(
"expect input height equals to input tile size");
3447 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3449 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3450 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3452 if (valueW != (rightTransform ? m + r - 1 : 1))
3453 return emitOpError(
"expect input width equals to input tile size");
3454 expectedOutputShape[getOutputWDim()] =
3455 (rightTransform ? m : 1) * valueTileW;
3457 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3458 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3460 auto outputType = cast<ShapedType>(getOutput().
getType());
3461 ArrayRef<int64_t> outputShape = outputType.getShape();
3463 return emitOpError(
"the output shape is not expected");
3469WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3470 Location loc = getLoc();
3473 Value value = getValue();
3474 int64_t valueRank = getValueOperandRank();
3475 SmallVector<Range> loopBounds(valueRank);
3476 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3477 loopBounds[dim].offset = zeroAttr;
3479 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3480 loopBounds[dim].stride = oneAttr;
3485SmallVector<utils::IteratorType>
3486WinogradOutputTransformOp::getLoopIteratorTypes() {
3487 int64_t valueRank = getValueOperandRank();
3488 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3489 utils::IteratorType::parallel);
3490 return iteratorTypes;
3493LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3494 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3495 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3496 SmallVector<OpFoldResult> &resultSizes) {
3497 WinogradConv2DFmr fmr = getFmr();
3501 Location loc = getLoc();
3503 auto identityAffineMap =
3508 ShapedType valueType = getValueOperandType();
3509 ArrayRef<int64_t> valueShape = valueType.getShape();
3510 int64_t valueH = valueShape[0];
3511 int64_t valueW = valueShape[1];
3513 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3514 offsets[getValueTileHDim()]);
3516 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3517 offsets[getValueTileWDim()]);
3519 builder, loc, affineMap, sizes[getValueTileHDim()]);
3521 builder, loc, affineMap, sizes[getValueTileWDim()]);
3524 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3525 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3526 OpFoldResult sizeH =
3527 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3528 OpFoldResult sizeW =
3529 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3531 resultOffsets.append(
3532 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3534 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3544FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3545 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3546 ArrayRef<OpFoldResult> sizes) {
3549 Location loc = getLoc();
3550 SmallVector<Value> tiledOperands;
3551 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3553 ShapedType valueType = getValueOperandType();
3554 ArrayRef<int64_t> valueShape = valueType.getShape();
3555 int64_t alphaH = valueShape[getValueAlphaHDim()];
3556 int64_t alphaW = valueShape[getValueAlphaWDim()];
3560 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3561 offsets[getValueTileWDim()], offsets[getValueNDim()],
3562 offsets[getValueFDim()]});
3563 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3564 sizes[getValueTileWDim()], sizes[getValueNDim()],
3565 sizes[getValueFDim()]});
3566 int64_t valueRank = getValueOperandRank();
3567 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3568 auto valueSlice = tensor::ExtractSliceOp::create(
3569 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3570 tiledOperands.emplace_back(valueSlice);
3572 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3577 int64_t outputRank = getOutputOperandRank();
3578 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3579 auto outputSlice = tensor::ExtractSliceOp::create(
3580 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3581 tiledOperands.emplace_back(outputSlice);
3583 SmallVector<Type> resultTypes;
3584 resultTypes.push_back(tiledOperands[1].
getType());
3585 Operation *tiledOp =
3586 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3588 return TilingResult{
3591 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3605 llvm::set_union(explicitSet, defaultSet);
3606 return explicitSet == defaultSet;
3626 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3628 auto opIndexingMap = opIndexingMaps[opIndex];
3629 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3632 return matmulOp->emitOpError()
3633 <<
"Unexpected dim expression in map result.";
3636 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3637 return matmulOp->emitOpError()
3638 <<
"Invalid broadcast requested, should be (d2).";
3647template <
typename OpTy>
3650 AffineMap defaultIndexingMap,
bool isLHS) {
3651 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3652 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3653 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3656 return batchVariantMatmulOp->emitOpError()
3657 <<
"Unexpected result dim expression (outside the set of default "
3662 return batchVariantMatmulOp->emitOpError()
3663 <<
"no. of result dim expressions exceeds 3.";
3665 auto hasValidBatchDim = [](
AffineMap map) {
3672 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3673 return batchVariantMatmulOp->emitOpError()
3674 <<
"Invalid broadcast requested.";
3675 }
else if (!hasValidBatchDim(opIndexingMap)) {
3676 return batchVariantMatmulOp->emitOpError()
3677 <<
"Invalid batch dimension expression.";
3685template <
typename OpTy>
3688 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3689 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3690 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3691 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3694 return batchVariantMatmulOp->emitOpError()
3695 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3698 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3700 return batchVariantMatmulOp->emitOpError()
3701 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3705 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3706 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3707 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3708 outputMap.getResult(1).isFunctionOfDim(1) &&
3709 outputMap.getResult(2).isFunctionOfDim(2)
3710 : outputMap.getResult(0).isFunctionOfDim(1) &&
3711 outputMap.getResult(1).isFunctionOfDim(2);
3714 if (!areValidOutputResultDim(opIndexingMap)) {
3715 return batchVariantMatmulOp->emitOpError()
3716 <<
"Invalid output map result dimension.";
3725template <
typename OpTy>
3730 batchVariantMatmulOp.getIndexingMapsArray();
3732 batchVariantMatmulOp.getDefaultIndexingMaps(
3733 batchVariantMatmulOp->getContext());
3735 if (opIndexingMaps.size() != 3)
3736 return batchVariantMatmulOp->emitOpError()
3737 <<
"Indexing_map attribute must have 3 affine maps.";
3739 auto opIndexingMap = opIndexingMaps[opIndex];
3740 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3748 defaultIndexingMap, opIndex == 0)))
3758 if (m == 2 && r == 3)
3759 return WinogradConv2DFmr::F_2_3;
3760 if (m == 4 && r == 3)
3761 return WinogradConv2DFmr::F_4_3;
3762 if (m == 2 && r == 5)
3763 return WinogradConv2DFmr::F_2_5;
3764 return std::nullopt;
3769 case WinogradConv2DFmr::F_2_3:
3771 case WinogradConv2DFmr::F_4_3:
3773 case WinogradConv2DFmr::F_2_5:
3776 llvm_unreachable(
"Unkown WinogradConv2DFmr");
3783static FailureOr<SmallVector<SmallVector<int64_t>>>
3786 for (
auto map : maps) {
3787 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3791 for (
auto result : attr.getAffineMap().getResults()) {
3792 auto dim = dyn_cast<AffineDimExpr>(
result);
3795 pos.push_back(dim.getPosition());
3797 positions.push_back(pos);
3810 return indexingMaps;
3813bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3814 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3817 if (maps.size() != 3)
3822 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3823 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3824 (*positions)[2] == SmallVector<int64_t>{0, 1};
3827SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3828 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3829 utils::IteratorType::parallel,
3830 utils::IteratorType::reduction};
3833unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3835std::string MatmulOp::getLibraryCallName() {
3839bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3843bool MatmulOp::hasUserDefinedMaps() {
3844 SmallVector<AffineMap, 3> defaultMaps =
3846 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3847 return defaultMaps != explicitMaps;
3852void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
3853 ArrayRef<NamedAttribute> attrs,
3856 emitError() <<
"MatmulOp regionBuilder expects 3 args, got "
3861 "MatmulOp regionBuilder expects 3 args");
3862 RegionBuilderHelper helper(
b, block);
3863 SmallVector<Value> yields;
3865 TypeFn castVal = TypeFn::cast_signed;
3866 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3867 return attr.
getName() ==
"cast";
3869 if (castIter != attrs.end()) {
3870 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3878 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2,
emitError);
3879 if (!value1 || !value2 || !value3)
3881 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
3885 yields.push_back(value4);
3886 helper.yieldOutputs(yields);
3896bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3897 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3898 AffineExpr expr = bcastMap.
getResult(0);
3908 ArrayAttr arrayAttr;
3912 if (llvm::any_of(arrayAttr,
3913 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3915 <<
"element of indexing_maps array is not an affine_map";
3922 if (failed(indexingMapsAttr))
3925 if (*indexingMapsAttr ==
nullptr) {
3926 auto indexingMapAttrs = llvm::map_to_vector(
3927 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3932 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
3934 MatmulOp::getRegionBuilder());
3937void MatmulOp::print(OpAsmPrinter &p) {
3938 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3939 MatmulOp::getDefaultIndexingMaps(
getContext()),
3940 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
3941 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3942 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3944 std::array<StringRef, 3> elidedAttrs = {
3945 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3951LogicalResult MatmulOp::verify() {
3953 if (!hasUserDefinedMaps())
3956 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3963LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3967void MatmulOp::getEffects(
3968 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3970 if (hasPureTensorSemantics())
3979SmallVector<AffineMap>
3980MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3981 AffineExpr d0, d1, d2;
3987 return {mapLHS, mapRHS, mapOut};
3991 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3994 if (maps.size() != 3)
3997 if (failed(positions))
4009 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4017 build(builder, state, inputs, outputs, attributes);
4018 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4019 assert(res &&
"builder didn't return the right type");
4029 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4038 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4039 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4040 assert(res &&
"builder didn't return the right type");
4050 result.addAttribute(
"cast", cast);
4052 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4061 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4062 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4063 assert(res &&
"builder didn't return the right type");
4068 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4070 op->
getAttr(
"indexing_maps"));
4074MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4081 return {mapLHS, mapRHS, mapOut};
4085 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4088 if (maps.size() != 3)
4091 if (failed(positions))
4103 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4111 build(builder, state, inputs, outputs, attributes);
4112 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4113 assert(res &&
"builder didn't return the right type");
4123 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4132 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4133 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4134 assert(res &&
"builder didn't return the right type");
4144 result.addAttribute(
"cast", cast);
4146 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4155 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4156 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4157 assert(res &&
"builder didn't return the right type");
4162 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4164 op->
getAttr(
"indexing_maps"));
4168BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4175 return {mapLHS, mapRHS, mapOut};
4179 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4182 if (maps.size() != 3)
4185 if (failed(positions))
4196 BatchMatmulOp::getRegionBuilder(),
4197 getDefaultIndexingMaps(builder));
4205 build(builder, state, inputs, outputs, attributes);
4206 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4207 assert(res &&
"builder didn't return the right type");
4216 BatchMatmulOp::getRegionBuilder(),
4217 getDefaultIndexingMaps(builder));
4226 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4227 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4228 assert(res &&
"builder didn't return the right type");
4236 result.addAttribute(
"cast", cast);
4238 BatchMatmulOp::getRegionBuilder(),
4239 getDefaultIndexingMaps(builder));
4248 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4249 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4250 assert(res &&
"builder didn't return the right type");
4255 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4257 op->
getAttr(
"indexing_maps"));
4261BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4268 return {mapLHS, mapRHS, mapOut};
4272 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4275 if (maps.size() != 3)
4278 if (failed(positions))
4289 BatchMatmulOp::getRegionBuilder(),
4290 getDefaultIndexingMaps(builder));
4298 build(builder, state, inputs, outputs, attributes);
4299 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4300 assert(res &&
"builder didn't return the right type");
4309 BatchMatmulOp::getRegionBuilder(),
4310 getDefaultIndexingMaps(builder));
4319 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4320 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4321 assert(res &&
"builder didn't return the right type");
4329 result.addAttribute(
"cast", cast);
4331 BatchMatmulOp::getRegionBuilder(),
4332 getDefaultIndexingMaps(builder));
4341 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4342 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4343 assert(res &&
"builder didn't return the right type");
4348 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4350 op->
getAttr(
"indexing_maps"));
4358 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4369 auto dimExpr = dyn_cast<AffineDimExpr>(
result);
4370 assert(dimExpr &&
"affine_map is a projected permutation");
4371 dimsInOutput[dimExpr.getPosition()] =
true;
4375 for (
auto dimOccursInOutput : dimsInOutput)
4376 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4377 : utils::IteratorType::reduction);
4379 return iteratorTypes;
4382unsigned ContractOp::getNumRegionArgs() {
return 3; }
4385void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
4386 ArrayRef<NamedAttribute> attrs,
4389 emitError() <<
"ContractOp regionBuilder expects 3 args, got "
4394 "ContractOp regionBuilder expects 3 args");
4395 RegionBuilderHelper helper(
b, block);
4397 TypeFn castSignedness = TypeFn::cast_signed;
4398 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4399 return attr.
getName() ==
"cast";
4401 if (castIter != attrs.end()) {
4402 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4408 Value lhsAtOutType =
4409 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
4410 Value rhsAtOutType =
4411 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
4412 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4414 if (!productAtOutType)
4420 helper.yieldOutputs({
result});
4423ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
4425 if (
failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
4427 "expected 'indexing_maps' attribute");
4428 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4434void ContractOp::print(OpAsmPrinter &p) {
4435 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4437 p, getOperation(), getInputs(), getOutputs(),
4438 {
"indexing_maps",
"operandSegmentSizes"});
4441LogicalResult ContractOp::verify() {
4442 int iterationSpaceDims = -1;
4447 SmallVector<size_t> inOccurrences;
4448 SmallVector<size_t> outOccurrences;
4451 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4452 bool isInput) -> LogicalResult {
4455 return emitError(
"provided affine_map is not a projected permutation");
4458 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
4460 return emitError(
"ranks of shaped operand and results of corresponding "
4461 "affine_map differ");
4463 return emitError(
"affine_map specifies shaped access while operand has "
4468 if (iterationSpaceDims == -1) {
4470 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4471 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4472 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
4473 return emitError(
"iteration spaces of provided affine_maps differ");
4477 for (AffineExpr affineExpr : affineMap.
getResults()) {
4478 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4480 llvm_unreachable(
"affine_map is a projected permutation");
4483 inOccurrences[affineDimExpr.getPosition()] += 1;
4485 outOccurrences[affineDimExpr.getPosition()] += 1;
4491 for (
auto &&[affineMap, operandType, isInput] :
4492 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4493 SmallVector<bool>{
true,
true,
false})) {
4494 if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4498 bool hasContractingDim =
false;
4499 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4500 size_t inOccCount = inOccurrences[dimIndex];
4501 size_t outOccCount = outOccurrences[dimIndex];
4504 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4506 if (inOccCount == 0 && outOccCount == 0)
4507 return emitError() <<
"iteration space dim at index " << dimIndex
4508 <<
" not used to access any operand";
4519 if (inOccCount == 1 && outOccCount != 1)
4521 <<
"iteration space dim at index " << dimIndex
4522 <<
" is neither a contracting dim nor of parallel iteration type";
4525 if (!hasContractingDim)
4526 return emitError(
"'indexing_maps' do not specify a contracting dimension");
4531LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4535void ContractOp::getEffects(
4536 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4538 if (hasPureTensorSemantics())
4550SmallVector<AffineMap>
4551BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4552 AffineExpr d0, d1, d2, d3;
4553 SmallVector<AffineMap> indexingMaps;
4555 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
4556 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
4557 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
4558 return indexingMaps;
4561bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4562 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4565 if (maps.size() != 3)
4570 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4571 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4572 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4575SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4576 return SmallVector<utils::IteratorType>{
4577 utils::IteratorType::parallel, utils::IteratorType::parallel,
4578 utils::IteratorType::parallel, utils::IteratorType::reduction};
4581unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
4583std::string BatchMatmulOp::getLibraryCallName() {
4589bool BatchMatmulOp::hasUserDefinedMaps() {
4590 SmallVector<AffineMap, 3> defaultMaps =
4592 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4593 return defaultMaps != explicitMaps;
4603bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
bool isLHS) {
4605 "Expected less than 3 result dim expr.");
4606 bool isValid =
false;
4607 enum Indices { batchPos, mPos, nPos, kPos };
4609 AffineExpr expr = bcastMap.
getResult(0);
4612 AffineExpr expr0 = bcastMap.
getResult(0);
4613 AffineExpr expr1 = bcastMap.
getResult(1);
4618 : ((expr0.isFunctionOfDim(batchPos) &&
4619 expr1.isFunctionOfDim(kPos)) ||
4620 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4625void BatchMatmulOp::regionBuilder(
4626 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4629 emitError() <<
"BatchMatmulOp regionBuilder expects 3 args, got "
4634 "BatchMatmulOp regionBuilder expects 3 args");
4635 RegionBuilderHelper helper(
b, block);
4636 SmallVector<Value> yields;
4638 TypeFn castVal = TypeFn::cast_signed;
4639 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4640 return attr.
getName() ==
"cast";
4642 if (castIter != attrs.end()) {
4643 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4648 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
4649 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
4651 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
4652 if (!castValA || !castValB || !mulVal)
4654 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
4658 yields.push_back(addVal);
4659 helper.yieldOutputs(yields);
4662ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
4663 SmallVector<Attribute, 3> indexingMapsAttr;
4675 if (!isa<AffineMapAttr>(mapAttr)) {
4677 "expected affine map attribute");
4679 indexingMapsAttr.push_back(mapAttr);
4689 if (indexingMapsAttr.empty()) {
4690 indexingMapsAttr = llvm::map_to_vector(
4691 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4692 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4694 result.addAttribute(
"indexing_maps",
4697 return ::parseNamedStructuredOp(parser,
result,
4698 BatchMatmulOp::getNumRegionArgs(),
4699 BatchMatmulOp::getRegionBuilder());
4702void BatchMatmulOp::print(OpAsmPrinter &p) {
4703 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4704 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4705 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4706 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4707 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4709 std::array<StringRef, 3> elidedAttrs = {
4710 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4716LogicalResult BatchMatmulOp::verify() {
4719 if (!hasUserDefinedMaps())
4722 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4729LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4730 SmallVectorImpl<OpFoldResult> &) {
4734void BatchMatmulOp::getEffects(
4735 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4737 if (hasPureTensorSemantics())
4751struct ArityGroupAndKind {
4753 ElementwiseArityGroup arityGroup;
4759 TernaryFn ternaryFn;
4763unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4764 return static_cast<unsigned>(arityGroup);
4769 constexpr int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4770 constexpr int lastBinary =
4771 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4772 constexpr int lastTernary =
4773 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4775 int val =
static_cast<int>(kind);
4776 ArityGroupAndKind
result;
4778 if (val < lastUnary) {
4779 result.arityGroup = ElementwiseArityGroup::Unary;
4780 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4783 if (val < lastBinary) {
4784 result.arityGroup = ElementwiseArityGroup::Binary;
4785 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4788 if (val >= lastTernary) {
4789 llvm_unreachable(
"unhandled ElementwiseFn");
4791 result.arityGroup = ElementwiseArityGroup::Ternary;
4792 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4797 auto rank = getResultRank();
4802ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
4808ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
4811 mlir::linalg::ElementwiseKind elemwiseKindVal;
4816 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4817 if (!elemwiseKindAttr)
4819 "expected ElementwiseKind attribute");
4820 elemwiseKindVal = elemwiseKindAttr.getValue();
4823 "expected operation 'kind' attribute");
4826 "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
4829 SmallVector<Attribute, 3> indexingMapsAttr;
4839 if (!isa<AffineMapAttr>(mapAttr))
4841 "expected affine map attribute");
4842 indexingMapsAttr.push_back(mapAttr);
4853 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4855 ElementwiseOp::getRegionBuilder())) {
4857 "unable to parse elemwise op");
4861 if (indexingMapsAttr.empty()) {
4864 auto resultType =
result.operands[
result.operands.size() - 1].getType();
4865 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4868 "return type needs to be shaped type");
4869 auto numDims = shapedType.getRank();
4870 indexingMapsAttr = llvm::map_to_vector(
4871 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4873 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4876 result.addAttribute(
"indexing_maps",
4881void ElementwiseOp::print(OpAsmPrinter &p) {
4884 SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes",
"kind",
4888 unsigned numDims = getResultRank();
4890 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4891 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4893 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4895 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4896 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4904void ElementwiseOp::regionBuilder(
4905 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4907 ElementwiseKind elemwiseKind;
4908 for (
auto attr : attrs) {
4909 if (attr.getName() ==
b.getStringAttr(
"kind")) {
4910 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4911 assert(kindAttr &&
"op kind attribute incorrectly set");
4912 elemwiseKind = kindAttr.getValue();
4918 auto arityGroup = groupAndKind.arityGroup;
4919 auto kind = groupAndKind.kind;
4921 getArityGroupAsUInt(arityGroup) + 1 ) {
4922 emitError() <<
"Elementwise regionBuilder expects "
4923 << (getArityGroupAsUInt(arityGroup) + 1) <<
" args, got "
4928 getArityGroupAsUInt(arityGroup) + 1
4929 &&
"Elementwise regionBuilder number of block args mismatch");
4931 RegionBuilderHelper helper(
b, block);
4932 SmallVector<Value> yields;
4935 if (arityGroup == ElementwiseArityGroup::Unary) {
4938 }
else if (arityGroup == ElementwiseArityGroup::Binary) {
4942 }
else if (arityGroup == ElementwiseArityGroup::Ternary) {
4947 assert(
false &&
"found unhandled category in elemwise");
4950 yields.push_back(
result);
4951 helper.yieldOutputs(yields);
4954LogicalResult ElementwiseOp::fold(FoldAdaptor,
4955 SmallVectorImpl<OpFoldResult> &) {
4959void ElementwiseOp::getEffects(
4960 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4962 if (hasPureTensorSemantics())
4975template <
typename OpTy,
typename>
4978 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4979 ? packOrUnPack.getDestType()
4980 : packOrUnPack.getSourceType();
4981 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
4982 ? packOrUnPack.getSourceType()
4983 : packOrUnPack.getDestType();
4985 packedType.getShape().take_front(unpackedType.getRank()));
4986 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5008 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5010 .take_back(mixedTiles.size()),
5013 if (
shape == ShapedType::kDynamic) {
5014 newMixedTileSizes.push_back(std::get<1>(it));
5021 if (
Attribute attr = llvm::dyn_cast_if_present<Attribute>(
tile)) {
5023 newMixedTileSizes.push_back(
tile);
5026 "tile size and dim size don't match!");
5027 newMixedTileSizes.push_back(
5032 return newMixedTileSizes;
5035template <
typename OpTy>
5039 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5040 "applies to only pack or unpack operations");
5041 int64_t destRank = op.getDestRank();
5043 for (
auto dim : llvm::seq<int64_t>(0, destRank))
5044 reifiedReturnShapes[0][dim] =
5049template <
typename OpTy>
5051 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5052 "applies to only pack or unpack operations");
5056 assert(tiles.size() == dimsToTile.size() &&
5057 "tiles must match indices of dimension to block");
5059 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5060 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5061 return dimAndTileMapping;
5064template <
typename OpTy>
5066 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5067 "applies to only pack or unpack operations");
5070 unsigned dynamicValIndex = 0;
5071 for (
int64_t staticTile : op.getStaticInnerTiles()) {
5072 if (ShapedType::isStatic(staticTile))
5075 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5077 return mixedInnerTiles;
5080template <
typename OpTy>
5082 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5083 "applies to only pack or unpack operations");
5096 size_t dimsPosSize = dimsPos.size();
5097 if (dimsPosSize > rank)
5100 if (dimsPosSize != uniqued.size())
5102 return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
5103 return dimPos < 0 || dimPos >=
static_cast<int64_t>(rank);
5107template <
typename OpTy>
5109 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5110 "applies to only pack or unpack operations");
5111 Operation *op = packOrUnPack.getOperation();
5119 if (!packOrUnPack.getSourceType().hasRank() ||
5120 !packOrUnPack.getDestType().hasRank())
5121 return op->
emitError(
"expected both source and destination to have rank");
5124 if (!packOrUnPack.hasPureBufferSemantics() &&
5125 !packOrUnPack.hasPureTensorSemantics())
5126 return op->
emitError(
"mixing tensor and buffer semantics is not allowed");
5127 const unsigned numResults = packOrUnPack.getNumResults();
5128 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5129 return op->
emitError(
"expected 1 result, got ") << numResults;
5130 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5131 return op->
emitError(
"expected 0 results, got ") << numResults;
5135 if (hasZeros(mixedTiles))
5136 return op->
emitError(
"invalid zero tile factor");
5139 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5140 ? packOrUnPack.getSourceType()
5141 : packOrUnPack.getDestType();
5142 size_t unpackedRank = unpackedType.getRank();
5146 return op->
emitError(
"invalid inner_dims_pos vector");
5148 return op->
emitError(
"invalid outer_dims_perm vector");
5149 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5150 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
5154 if (mixedTiles.size() > unpackedRank) {
5155 return op->
emitError(
"tiling factors must be less than or equal to the "
5156 "input rank for pack or output rank for unpack");
5158 if (mixedTiles.size() != innerDimsPos.size()) {
5160 "tiling factors must equal the number of dimensions to tile");
5163 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5164 ? packOrUnPack.getDestType()
5165 : packOrUnPack.getSourceType();
5166 size_t packedRank = packedType.getRank();
5168 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5169 if (expectedPackedRank != packedRank) {
5171 "packed rank != (unpacked rank + num tiling factors), got ")
5172 << packedRank <<
" != " << expectedPackedRank;
5179 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5180 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5182 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
5184 [](std::tuple<int64_t, OpFoldResult> it) {
5185 int64_t shape = std::get<0>(it);
5186 if (Attribute attr =
5187 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
5188 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5189 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5190 return shape == staticTileSize;
5192 return ShapedType::isDynamic(
shape);
5194 return op->emitError(
"mismatch in inner tile sizes specified and shaped of "
5195 "tiled dimension in the packed type");
5199 auto elementType = unpackedType.getElementType();
5200 Type expectedType, actualType;
5201 if (packOrUnPack.hasPureTensorSemantics()) {
5202 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5203 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5205 expectedType = MemRefType::get(expectedPackedShape, elementType);
5206 actualType = MemRefType::get(packedType.getShape(), elementType);
5208 return op->emitError(
"expected ")
5209 << expectedType <<
" for the packed domain value, got "
5222struct PackOrUnPackTransposeResult {
5229template <
typename OpTy>
5230static PackOrUnPackTransposeResult
5234 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5235 "applies to only pack or unpack operations");
5236 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5237 "some permutation must be non-empty");
5238 PackOrUnPackTransposeResult metadata;
5239 metadata.innerDimsPos =
5241 metadata.innerTiles =
5243 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5244 ? packOrUnPackOp.getSourceRank()
5245 : packOrUnPackOp.getDestRank();
5246 metadata.outerDimsPerm =
5247 packOrUnPackOp.getOuterDimsPerm().empty()
5248 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5250 if (!innerPermutation.empty()) {
5251 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5253 "invalid inner permutation");
5257 if (!outerPermutation.empty()) {
5258 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5260 "invalid outer permutation");
5271 if (!getResults().empty())
5272 setNameFn(getResult(),
"pack");
5282 Type sourceType, destType, resultType;
5302 if (parser.parseInteger(value))
5304 outerDimsPermVec.push_back(value);
5317 if (parser.parseInteger(value))
5319 innerDimsPosVec.push_back(value);
5331 for (
auto val : staticTilesAttr.
asArrayRef())
5332 staticTiles.push_back(val);
5349 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5352 "pack/unpack requires '->' and destination type");
5356 resultType = destType;
5362 if (!paddingValue.empty() &&
5367 if (!dynamicTiles.empty() &&
5372 result.addAttribute(
"static_inner_tiles",
5374 result.addAttribute(
"inner_dims_pos", innerDimsPos);
5376 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
5379 1, 1,
static_cast<int32_t
>(paddingValue.size()),
5380 static_cast<int32_t
>(dynamicTiles.size())};
5381 result.addAttribute(
"operandSegmentSizes",
5385 result.addTypes(resultType);
5391 p <<
" " << getSource();
5393 if (getPaddingValue()) {
5394 p <<
" padding_value(" << getPaddingValue() <<
" : "
5395 << getPaddingValue().getType() <<
")";
5398 if (!getOuterDimsPerm().empty()) {
5399 p <<
" outer_dims_perm = [";
5400 llvm::interleaveComma(getOuterDimsPerm(), p);
5404 p <<
" inner_dims_pos = [";
5405 llvm::interleaveComma(getInnerDimsPos(), p);
5408 p <<
" inner_tiles = ";
5411 p <<
" into " << getDest();
5414 {
"static_inner_tiles",
"inner_dims_pos",
5415 "outer_dims_perm",
"operandSegmentSizes"});
5417 p <<
" : " << getSource().getType();
5418 p <<
" -> " << getDest().getType();
5424 std::optional<Value> paddingValue,
5426 assert(innerDimsPos.size() == innerTiles.size() &&
5427 "number of tile sizes specified must match the specified number of "
5428 "original dimensions to be tiled");
5432 build(builder, state, dest.
getType(), source, dest,
5433 paddingValue ? *paddingValue :
nullptr,
5434 outerDimsPerm.empty() ?
nullptr
5441PackOp::reifyResultShapes(
OpBuilder &builder,
5459 ShapedType inputType = getSourceType();
5460 int64_t inputRank = inputType.getRank();
5461 return getDestType().getShape().take_front(inputRank);
5465 auto innerDimsPos = getInnerDimsPos();
5472 if (!outerDimPermInv.empty())
5476 for (
auto index : innerDimsPos)
5477 res.push_back(outerDims[
index]);
5488 outputShape.take_front(inputShape.size()));
5489 if (!outerDimsPerm.empty()) {
5490 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5491 "expected output and outer_dims_perm to have same size");
5495 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5496 if (ShapedType::isDynamic(inputShape[pos]))
5500 if (!constantTile) {
5501 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5502 (inputShape[pos] % outputTileSizes[pos] != 0))
5504 }
else if (inputShape[pos] % (*constantTile) != 0) {
5517 outputShape.take_front(inputShape.size()));
5518 if (!outerDimsPerm.empty()) {
5519 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5520 "expected output and outer_dims_perm to have same size");
5524 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5525 if (ShapedType::isDynamic(inputShape[pos]) ||
5526 ShapedType::isDynamic(outputTileSizes[pos]))
5531 if (inputShape[pos] % (*constantTile) != 0)
5537LogicalResult PackOp::verify() {
5544 auto paddingValue = getPaddingValue();
5548 << getSourceType().getElementType()
5549 <<
" but got: " << paddingValue.getType();
5552 if (!paddingValue &&
5553 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
5554 getDestType().
getShape(), getOuterDimsPerm(),
5557 "invalid tile factor or output size provided. Only full tiles are "
5558 "supported when padding_value is not set");
5568 for (
auto o : ofrs) {
5570 if (llvm::dyn_cast_if_present<Value>(o))
5571 result.push_back(ShapedType::kDynamic);
5583 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5584 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5586 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5587 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5590 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5591 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5595 if (!outerDimsPerm.empty())
5599 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5612 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5614 builder, loc, ceilDivExpr,
5615 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5617 if (!outerDimsPerm.empty())
5619 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5624 innerDimsPos, outerDimsPerm);
5630 for (
unsigned i = 0; i < resultDims.size(); ++i) {
5631 if (ShapedType::isStatic(resultTypeShape[i]))
5640RankedTensorType PackOp::inferPackedTensorType(
5644 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5645 return RankedTensorType::get(resultShape, sourceType.getElementType());
5648MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5653 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5654 return MemRefType::get(resultShape, sourceType.getElementType());
5669 for (
auto [
index, value] : llvm::enumerate(
5670 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
5671 if (ShapedType::isDynamic(value))
5672 mixedSizes.push_back(
5673 tensor::DimOp::create(
b, loc, source,
index).getResult());
5675 mixedSizes.push_back(
b.getIndexAttr(value));
5677 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5678 int64_t dimPos = std::get<0>(it);
5680 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5682 if (!outerDimsPerm.empty())
5685 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5686 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5687 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
5694 *
this, innerPermutation, outerPermutation);
5695 Value transposedDest =
5696 createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
5697 metadata.innerDimsPos, metadata.outerDimsPerm);
5698 return PackOp::create(
b, loc, getSource(), transposedDest,
5699 metadata.innerDimsPos, metadata.innerTiles,
5700 getPaddingValue(), metadata.outerDimsPerm);
5703template <
typename OpTy>
5708 if (op.hasPureTensorSemantics())
5711 for (
OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5712 if (!llvm::isa<MemRefType>(opOperand.
get().
getType()))
5715 if (&opOperand == &op.getSourceMutable()) {
5719 }
else if (&opOperand == &op.getDestMutable()) {
5730void PackOp::getEffects(
5736void UnPackOp::getEffects(
5743template <
typename OpTy>
5745 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5746 "applies to only pack or unpack operations");
5747 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5749 : op.getSourceType();
5751 for (
auto [dimDest,
tile] : llvm::zip(
5752 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5754 if (!constTileSize || ShapedType::isDynamic(dimDest))
5761 if (!hasPureTensorSemantics())
5763 if (getPaddingValue())
5778 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5780 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5792 auto packTiles = packOp.getMixedTiles();
5793 auto unPackTiles = unPackOp.getMixedTiles();
5794 if (packTiles.size() != unPackTiles.size())
5796 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
5805 auto srcType = op.getSourceType();
5806 if (llvm::any_of(op.getInnerDimsPos(),
5807 [&](
int64_t pos) { return srcType.isDynamicDim(pos); }))
5809 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5811 return !PackOp::requirePaddingValue(
5812 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5813 op.getOuterDimsPerm(), op.getMixedTiles());
5820 bool changeNeeded =
false;
5821 srcShape.assign(packOp.getSourceType().getShape().begin(),
5822 packOp.getSourceType().getShape().end());
5823 destShape.assign(packOp.getDestType().getShape().begin(),
5824 packOp.getDestType().getShape().end());
5825 llvm::SmallSetVector<int64_t, 4> innerDims;
5826 innerDims.insert_range(packOp.getInnerDimsPos());
5828 if (!packOp.getOuterDimsPerm().empty())
5830 int srcRank = packOp.getSourceRank();
5831 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
5832 if (innerDims.contains(i))
5836 if (!inverseOuterDimsPerm.empty())
5837 destPos = inverseOuterDimsPerm[srcPos];
5838 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5839 ShapedType::isDynamic(destShape[destPos])) {
5842 int64_t size = srcShape[srcPos];
5843 if (ShapedType::isDynamic(size))
5844 size = destShape[destPos];
5845 srcShape[srcPos] = size;
5846 destShape[destPos] = size;
5847 changeNeeded =
true;
5849 return changeNeeded;
5852LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
5854 if (!packOp.hasPureTensorSemantics())
5858 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5859 if (unPackOp.getSourceType() == packOp.getDestType() &&
5860 !packOp.getPaddingValue() &&
5863 rewriter.
replaceOp(packOp, unPackOp.getSource());
5871 packOp.getPaddingValueMutable().clear();
5880 Value source = packOp.getSource();
5881 if (srcShape != packOp.getSourceType().getShape()) {
5882 auto newSrcType = packOp.getSourceType().clone(srcShape);
5884 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5886 Value dest = packOp.getDest();
5887 ShapedType originalResultType = packOp.getDestType();
5888 bool needUpdateDestType = (destShape != originalResultType.getShape());
5889 if (needUpdateDestType) {
5890 auto newDestType = packOp.getDestType().clone(destShape);
5892 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5895 packOp.getSourceMutable().assign(source);
5896 packOp.getDestMutable().assign(dest);
5897 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
5900 if (needUpdateDestType) {
5902 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5903 packOp.getResult());
5912template <
typename PackOrUnpackOp>
5914 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5915 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5916 "Function meant for pack/unpack");
5921 int64_t numPackedDims = innerDimsPos.size();
5922 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5923 if (orderedDims != innerDimsPos) {
5929 int64_t packedRank = packedTensorType.getRank();
5939 return llvm::all_of(
5940 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5941 [&packedShape](
int64_t i) {
return packedShape[i] == 1; });
5944bool PackOp::isLikePad() {
5945 auto packedTensorType =
5946 llvm::cast<ShapedType>((*this)->getResultTypes().front());
5950::mlir::LogicalResult
5951PackOp::fold(FoldAdaptor adaptor,
5953 if (!hasPureTensorSemantics())
5955 std::optional<Attribute> paddingValue;
5956 if (
auto pad = adaptor.getPaddingValue())
5958 if (
OpFoldResult reshapedSource = reshapeConstantSource(
5959 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5960 cast<TensorType>(getDestType()), paddingValue)) {
5961 results.push_back(reshapedSource);
5987 if (!op.hasPureTensorSemantics())
6006 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6007 op.getInnerDimsPos(), newMixedTileSizes,
6008 op.getPaddingValue(), op.getOuterDimsPerm());
6009 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6012 Value oldResult = op.getResult();
6013 Value newResult = newOp.getResult();
6016 ? tensor::CastOp::create(rewriter, op->getLoc(),
6017 oldResult.
getType(), newResult)
6030void UnPackOp::getAsmResultNames(
6032 if (!getResults().empty())
6033 setNameFn(getResult(),
"unpack");
6042 Type sourceType, destType, resultType;
6054 if (parser.parseInteger(value))
6056 outerDimsPermVec.push_back(value);
6069 if (parser.parseInteger(value))
6071 innerDimsPosVec.push_back(value);
6083 for (
auto val : staticTilesAttr.
asArrayRef())
6084 staticTiles.push_back(val);
6101 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6104 "pack/unpack requires '->' and destination type");
6108 resultType = destType;
6114 if (!dynamicTiles.empty() &&
6119 result.addAttribute(
"static_inner_tiles",
6121 result.addAttribute(
"inner_dims_pos", innerDimsPos);
6123 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
6126 1, 1, 0,
static_cast<int32_t
>(dynamicTiles.size())};
6127 result.addAttribute(
"operandSegmentSizes",
6131 result.addTypes(resultType);
6137 p <<
" " << getSource();
6139 if (!getOuterDimsPerm().empty()) {
6140 p <<
" outer_dims_perm = [";
6141 llvm::interleaveComma(getOuterDimsPerm(), p);
6145 p <<
" inner_dims_pos = [";
6146 llvm::interleaveComma(getInnerDimsPos(), p);
6149 p <<
" inner_tiles = ";
6152 p <<
" into " << getDest();
6155 {
"static_inner_tiles",
"inner_dims_pos",
6156 "outer_dims_perm",
"operandSegmentSizes"});
6158 p <<
" : " << getSource().getType();
6159 p <<
" -> " << getDest().getType();
6163UnPackOp::reifyResultShapes(
OpBuilder &builder,
6181 ShapedType destType = getDestType();
6182 int64_t destRank = destType.getRank();
6183 return getSourceType().getShape().take_front(destRank);
6187 auto innerDimsPos = getInnerDimsPos();
6194 if (!outerDimPermInv.empty())
6198 for (
auto index : innerDimsPos)
6199 res.push_back(outerDims[
index]);
6204LogicalResult UnPackOp::verify() {
6209 if (!hasPureTensorSemantics())
6222 assert(innerDimsPos.size() == innerTiles.size() &&
6223 "number of tile sizes specified must match the specified number of "
6224 "original dimensions to be tiled");
6228 build(builder, state, dest.
getType(), source, dest,
6229 outerDimsPerm.empty() ?
nullptr
6247 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
6249 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6250 if (srcType.isDynamicDim(i))
6251 mixedSizes.push_back(
6252 tensor::DimOp::create(
b, loc, source, i).getResult());
6254 mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
6256 if (!outerDimsPerm.empty()) {
6261 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6262 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6264 auto elemType = srcType.getElementType();
6265 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
6269 Value transposedSource,
6273 *
this, innerPermutation, outerPermutation);
6274 return UnPackOp::create(
b, loc, transposedSource, getDest(),
6275 metadata.innerDimsPos, metadata.innerTiles,
6276 metadata.outerDimsPerm);
6283 bool changeNeeded =
false;
6284 srcShape.assign(op.getSourceType().getShape().begin(),
6285 op.getSourceType().getShape().end());
6286 destShape.assign(op.getDestType().getShape().begin(),
6287 op.getDestType().getShape().end());
6288 llvm::SmallSetVector<int64_t, 4> innerDims;
6289 innerDims.insert_range(op.getInnerDimsPos());
6291 if (!op.getOuterDimsPerm().empty())
6293 int destRank = op.getDestRank();
6294 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
6295 if (innerDims.contains(i))
6299 if (!inverseOuterDimsPerm.empty())
6300 srcPos = inverseOuterDimsPerm[destPos];
6301 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6302 ShapedType::isDynamic(destShape[destPos])) {
6305 int64_t size = srcShape[srcPos];
6306 if (ShapedType::isDynamic(size))
6307 size = destShape[destPos];
6308 srcShape[srcPos] = size;
6309 destShape[destPos] = size;
6310 changeNeeded =
true;
6312 return changeNeeded;
6315LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6318 if (!unPackOp.hasPureTensorSemantics())
6322 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6323 if (packOp.getSourceType() != unPackOp.getDestType())
6325 if (packOp.getPaddingValue() ||
6329 rewriter.
replaceOp(unPackOp, packOp.getSource());
6333 if (
auto dstStyleOp =
6334 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6335 auto destValue = cast<OpResult>(unPackOp.getDest());
6336 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6338 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6342 if (unPackOp->hasOneUse()) {
6343 auto extractSliceUser =
6344 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6345 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6348 auto newDest = tensor::ExtractSliceOp::create(
6349 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6350 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6351 extractSliceUser.getMixedStrides());
6353 unPackOp.setDpsInitOperand(0, newDest);
6354 unPackOp.getResult().setType(newDest.
getType());
6356 rewriter.
replaceOp(extractSliceUser, unPackOp);
6365 Value source = unPackOp.getSource();
6366 if (srcShape != unPackOp.getSourceType().getShape()) {
6367 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6368 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6369 unPackOp.getSource());
6371 Value dest = unPackOp.getDest();
6372 if (destShape != unPackOp.getDestType().getShape()) {
6373 auto newDestType = unPackOp.getDestType().clone(destShape);
6374 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6375 unPackOp.getDest());
6377 UnPackOp newOp = UnPackOp::create(
6378 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6379 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6381 unPackOp, unPackOp.getResult().
getType(), newOp.getResult());
6388bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6390 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6395 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6399 for (
auto [pos, tileSize] :
6400 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6401 areOuterDimsTiled[pos] =
true;
6402 if (unpackedTypeAfterFold.isDynamicDim(pos))
6404 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6406 if (ShapedType::isDynamic(tileSize))
6408 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6409 unpackedTypeAfterFold.getDimSize(pos);
6410 if (paddingSize >= tileSize)
6414 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6415 if (areOuterDimsTiled[pos])
6417 int64_t dim = outerShapeWithoutTranspose[pos];
6418 if (ShapedType::isDynamic(dim))
6420 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6426bool UnPackOp::isLikeUnPad() {
6427 ShapedType packedTensorType = getSourceType();
6431::mlir::LogicalResult
6432UnPackOp::fold(FoldAdaptor adaptor,
6433 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6435 if (!hasPureTensorSemantics())
6438 if (
OpFoldResult reshapedSource = reshapeConstantSource(
6439 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6440 cast<TensorType>(getResult().
getType()))) {
6441 results.push_back(reshapedSource);
6467 if (!op.hasPureTensorSemantics())
6476 Value sourceTensor = newOperands[0];
6480 rewriter, sourceTensor.
getType(), op.getMixedTiles());
6486 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6487 newOperands[1], op.getInnerDimsPos(),
6488 newMixedTileSizes, op.getOuterDimsPerm());
6489 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6492 Value oldResult = op.getResult();
6493 Value newResult = newOp.getResult();
6496 ? tensor::CastOp::create(rewriter, op->getLoc(),
6497 oldResult.
getType(), newResult)
6511 utils::IteratorType::reduction, utils::IteratorType::parallel,
6512 utils::IteratorType::parallel, utils::IteratorType::reduction};
6516BatchReduceMatmulOp::getDefaultIndexingMaps(
MLIRContext *context) {
6520 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
6521 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
6523 return indexingMaps;
6526bool BatchReduceMatmulOp::isDefaultIndexingMaps(
Attribute attr) {
6527 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6530 if (maps.size() != 3)
6539unsigned BatchReduceMatmulOp::getNumRegionArgs() {
return 3; }
6541std::string BatchReduceMatmulOp::getLibraryCallName() {
6547bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6551 return defaultMaps != explicitMaps;
6561bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap,
6564 "Expected less than 3 result dim expr.");
6565 bool isValid =
false;
6566 enum Indices { batchPos, mPos, nPos, kPos };
6577 : ((expr0.isFunctionOfDim(batchPos) &&
6578 expr1.isFunctionOfDim(kPos)) ||
6579 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6584void BatchReduceMatmulOp::regionBuilder(
6588 emitError() <<
"BatchReduceMatmulOp regionBuilder expects 3 args, got "
6593 "BatchReduceMatmulOp regionBuilder expects 3 args");
6594 RegionBuilderHelper helper(
b, block);
6599 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
6601 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
6603 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
6604 if (!castValA || !castValB || !mulVal)
6607 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
6610 yields.push_back(addVal);
6611 helper.yieldOutputs(yields);
6614ParseResult BatchReduceMatmulOp::parse(
OpAsmParser &parser,
6627 if (!isa<AffineMapAttr>(mapAttr)) {
6629 "expected affine map attribute");
6631 indexingMapsAttr.push_back(mapAttr);
6641 if (indexingMapsAttr.empty()) {
6642 indexingMapsAttr = llvm::map_to_vector(
6643 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
6646 result.addAttribute(
"indexing_maps",
6648 return ::parseNamedStructuredOp(parser,
result,
6649 BatchReduceMatmulOp::getNumRegionArgs(),
6650 BatchReduceMatmulOp::getRegionBuilder());
6655 BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
6658 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6659 p <<
" indexing_maps = [";
6660 llvm::interleaveComma(getIndexingMaps(), p,
6666 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
6672LogicalResult BatchReduceMatmulOp::verify() {
6675 if (!hasUserDefinedMaps())
6678 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
6684LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6688void BatchReduceMatmulOp::getEffects(
6691 if (hasPureTensorSemantics())
6707void LinalgDialect::getCanonicalizationPatterns(
6716 return arith::ConstantOp::materialize(builder, value, type, loc);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
Fails if copies could not be generated due to yet unimplemented cases. copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock specify the insertion points where the incoming and outgoing copies should be placed. The output argument nBegin is set to the replacement of `begin` (set to `begin` itself if no invalidation happens), since outgoing copies could have been inserted at `end`.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override