41#include "llvm/ADT/DenseMap.h"
42#include "llvm/ADT/STLExtras.h"
43#include "llvm/ADT/SetOperations.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/SmallVectorExtras.h"
46#include "llvm/ADT/StringSet.h"
47#include "llvm/ADT/TypeSwitch.h"
48#include "llvm/Support/FormatVariadic.h"
49#include "llvm/Support/InterleavedRange.h"
50#include "llvm/Support/LogicalResult.h"
51#include "llvm/Support/MathExtras.h"
52#include "llvm/Support/raw_ostream.h"
62 auto type = cast<ShapedType>(v.
getType());
63 if (!type.isDynamicDim(dim))
68 .Case([&](RankedTensorType t) ->
Value {
69 return tensor::DimOp::create(builder, loc, v, dim);
71 .Case([&](MemRefType t) ->
Value {
72 return memref::DimOp::create(builder, loc, v, dim);
83 .Case([&](RankedTensorType t) ->
Operation * {
84 return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
87 .Case([&](MemRefType type) ->
Operation * {
88 return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
94static std::optional<TypedAttr>
98 if (!splatAttr || !splatAttr.
isSplat())
110 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
111 return b.createOrFold<memref::DimOp>(loc, source, dim);
112 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
113 return b.createOrFold<tensor::DimOp>(loc, source, dim);
114 llvm_unreachable(
"Expected MemRefType or TensorType");
119 auto shapedType = llvm::cast<ShapedType>(source.
getType());
120 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
122 return b.getIndexAttr(shapedType.getDimSize(dim));
145 for (
auto containers : {inputTypes, outputTypes}) {
146 for (
auto t : containers) {
158 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
174 std::optional<TypeRange> resultTensorTypes,
181 if (!resultTensorTypes)
182 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
183 llvm::IsaPred<RankedTensorType>);
191 "operandSegmentSizes",
192 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
193 static_cast<int32_t>(outputs.size())}));
203 std::optional<TypeRange> resultTensorTypes,
210 return attr.
getName() ==
"indexing_maps";
213 indexingMapsAttrVal = llvm::map_to_vector(
216 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
219 attributes, regionBuilder);
223 std::optional<TypeRange> resultTensorTypes,
230 return attr.
getName() ==
"indexing_maps";
233 indexingMapsAttrVal = llvm::map_to_vector(
236 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
239 attributes, regionBuilder);
243 std::optional<TypeRange> resultTensorTypes,
250 indexingMapsAttrVal =
252 return AffineMapAttr::get(map);
254 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
256 attributes, regionBuilder);
265 bool addOperandSegmentSizes =
true) {
266 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
295 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
297 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
301 if (addOperandSegmentSizes) {
308 if (
result.propertiesAttr) {
310 attrs.
append(
"operandSegmentSizes",
312 {static_cast<int32_t>(inputsOperands.size()),
313 static_cast<int32_t>(outputsOperands.size())}));
316 result.addAttribute(
"operandSegmentSizes",
318 {static_cast<int32_t>(inputsOperands.size()),
319 static_cast<int32_t>(outputsOperands.size())}));
322 if (!
result.propertiesAttr) {
323 std::optional<RegisteredOperationName> info =
324 result.name.getRegisteredInfo();
326 if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
327 return parser.emitError(attrsLoc)
328 <<
"'" << result.name.getStringRef() <<
"' op ";
339 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
340 if (!outputs.empty())
341 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
352 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
355 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
356 "region expects {0} args, got {1}",
357 numRegionArgs, inputTypes.size() + outputTypes.size()));
363 opBuilder, region, inputTypes, outputTypes, attrs,
382 unsigned numRegionArgs,
399 result.addTypes(outputTensorsTypes);
401 std::unique_ptr<Region> region = std::make_unique<Region>();
403 outputTypes,
result.attributes.getAttrs(),
406 result.addRegion(std::move(region));
413 if (resultTypes.empty())
458class RegionBuilderHelper {
/// Constructs a helper bound to `builder` and `block`. Both arguments are
/// forwarded to the same-named members through the initializer list; the
/// constructor body itself does nothing.
460 RegionBuilderHelper(OpBuilder &builder,
Block &block)
461 : builder(builder), block(block) {}
464 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
466 if (!isFloatingPoint(arg)) {
468 emitError() <<
"unsupported non numeric type";
471 llvm_unreachable(
"unsupported non numeric type");
473 OpBuilder::InsertionGuard g(builder);
474 builder.setInsertionPointToEnd(&block);
477 return math::ExpOp::create(builder, arg.
getLoc(), arg);
479 return math::LogOp::create(builder, arg.
getLoc(), arg);
481 return math::AbsFOp::create(builder, arg.
getLoc(), arg);
483 return math::CeilOp::create(builder, arg.
getLoc(), arg);
485 return math::FloorOp::create(builder, arg.
getLoc(), arg);
487 return arith::NegFOp::create(builder, arg.
getLoc(), arg);
488 case UnaryFn::reciprocal: {
489 Attribute oneAttr = builder.getOneAttr(arg.
getType());
490 auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
491 ::cast<TypedAttr>(oneAttr));
492 return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
495 return math::RoundOp::create(builder, arg.
getLoc(), arg);
497 return math::SqrtOp::create(builder, arg.
getLoc(), arg);
499 return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
500 case UnaryFn::square:
501 return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
503 return math::TanhOp::create(builder, arg.
getLoc(), arg);
505 return math::ErfOp::create(builder, arg.
getLoc(), arg);
507 return math::SinOp::create(builder, arg.
getLoc(), arg);
509 return math::CosOp::create(builder, arg.
getLoc(), arg);
511 return math::TanOp::create(builder, arg.
getLoc(), arg);
514 emitError() <<
"unsupported unary function";
517 llvm_unreachable(
"unsupported unary function");
524 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
526 bool allComplex = isComplex(arg0) && isComplex(arg1);
527 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
528 bool allInteger = isInteger(arg0) && isInteger(arg1);
531 if (!allComplex && !allFloatingPoint && !allInteger) {
534 <<
"Cannot build binary Linalg operation: expects allComplex, "
535 "allFloatingPoint, or allInteger, got "
539 llvm_unreachable(
"unsupported non numeric type");
541 OpBuilder::InsertionGuard g(builder);
542 builder.setInsertionPointToEnd(&block);
546 return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
547 if (allFloatingPoint)
548 return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
550 return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
551 return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
554 return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
555 if (allFloatingPoint)
556 return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
559 emitError() <<
"unsupported operation: sub with bools";
562 llvm_unreachable(
"unsupported operation: sub with bools");
564 return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
567 return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
568 if (allFloatingPoint)
569 return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
571 return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
572 return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
575 return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
576 if (allFloatingPoint)
577 return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
580 emitError() <<
"unsupported operation: div with bools";
583 llvm_unreachable(
"unsupported operation: div with bools");
585 return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
586 case BinaryFn::div_unsigned:
587 if (!allInteger || allBool) {
589 emitError() <<
"unsupported operation: unsigned div not on uint";
592 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
594 return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
595 case BinaryFn::max_signed:
597 if (allFloatingPoint)
598 return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
599 return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
600 case BinaryFn::min_signed:
602 if (allFloatingPoint)
603 return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
604 return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
605 case BinaryFn::max_unsigned:
607 if (!allInteger || allBool) {
609 emitError() <<
"unsupported operation: unsigned max not on uint";
612 llvm_unreachable(
"unsupported operation: unsigned max not on uint");
614 return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
615 case BinaryFn::min_unsigned:
617 if (!allInteger || allBool) {
619 emitError() <<
"unsupported operation: unsigned min not on uint";
622 llvm_unreachable(
"unsupported operation: unsigned min not on uint");
624 return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
626 assert(allFloatingPoint);
627 return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
630 emitError() <<
"unsupported binary function";
633 llvm_unreachable(
"unsupported binary function");
637 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
639 OpBuilder::InsertionGuard g(builder);
640 builder.setInsertionPointToEnd(&block);
642 case TernaryFn::select:
643 return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
646 emitError() <<
"unsupported ternary function";
649 llvm_unreachable(
"unsupported ternary function");
653 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
656 case TypeFn::cast_signed:
657 return cast(toType, operand,
false);
658 case TypeFn::cast_unsigned:
659 return cast(toType, operand,
true);
662 emitError() <<
"unsupported type conversion function";
665 llvm_unreachable(
"unsupported type conversion function");
669 OpBuilder::InsertionGuard g(builder);
670 builder.setInsertionPointToEnd(&block);
671 Location loc = builder.getUnknownLoc();
672 YieldOp::create(builder, loc, values);
/// Creates an `arith::ConstantOp` at the end of `block` from the textual
/// attribute `value`.
///
/// `value` is parsed with `parseAttribute` in the builder's context, the
/// result is cast to `TypedAttr`, and the constant is created with an unknown
/// location. Returns the value produced by `arith::ConstantOp::create`.
675 Value constant(
const std::string &value) {
// The guard is expected (per MLIR's OpBuilder API) to restore the builder's
// previous insertion point when `g` goes out of scope.
676 OpBuilder::InsertionGuard g(builder);
677 builder.setInsertionPointToEnd(&block);
678 Location loc = builder.getUnknownLoc();
// NOTE(review): the parse result is cast below without a null check;
// presumably callers only pass well-formed typed literals -- confirm.
679 Attribute valueAttr =
parseAttribute(value, builder.getContext());
680 return arith::ConstantOp::create(builder, loc,
681 ::cast<TypedAttr>(valueAttr));
/// Creates an `IndexOp` for dimension `dim` at the end of `block`, using an
/// unknown location, and returns the result of `IndexOp::create`.
684 Value index(int64_t dim) {
// The guard is expected (per MLIR's OpBuilder API) to restore the builder's
// previous insertion point when `g` goes out of scope.
685 OpBuilder::InsertionGuard g(builder);
686 builder.setInsertionPointToEnd(&block);
687 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
/// Returns the `IntegerType` of bit width `width`, obtained from the
/// builder's context via `IntegerType::get`.
690 Type getIntegerType(
unsigned width) {
691 return IntegerType::get(builder.getContext(), width);
/// Returns the `Float32Type` obtained from the builder's context.
694 Type getFloat32Type() {
return Float32Type::get(builder.getContext()); }
/// Returns the `Float64Type` obtained from the builder's context.
695 Type getFloat64Type() {
return Float64Type::get(builder.getContext()); }
702 Value cast(Type toType, Value operand,
bool isUnsignedCast) {
703 OpBuilder::InsertionGuard g(builder);
704 builder.setInsertionPointToEnd(&block);
705 auto loc = operand.
getLoc();
706 if (isa<UnknownLoc>(loc)) {
/// Returns true when the type of `value` is a `ComplexType`. The check is on
/// the value's own type as returned by `getType()`.
716 bool isComplex(Value value) {
717 return llvm::isa<ComplexType>(value.
getType());
/// Returns true when the type of `value` is a `FloatType`. The check is on
/// the value's own type as returned by `getType()`.
719 bool isFloatingPoint(Value value) {
720 return llvm::isa<FloatType>(value.
getType());
/// Returns true when the type of `value` is an `IntegerType`. The check is on
/// the value's own type as returned by `getType()`.
722 bool isInteger(Value value) {
723 return llvm::isa<IntegerType>(value.
getType());
739 using OpRewritePattern<CopyOp>::OpRewritePattern;
740 LogicalResult matchAndRewrite(CopyOp copyOp,
741 PatternRewriter &rewriter)
const override {
742 if (copyOp.getInputs() != copyOp.getOutputs())
744 if (copyOp.hasPureBufferSemantics())
747 rewriter.
replaceOp(copyOp, copyOp.getInputs());
757 results.
add<EraseSelfCopy>(context);
770template <
typename TensorReshapeOp>
771struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
772 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
773 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
774 PatternRewriter &rewriter)
const override {
775 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
779 Location loc = oldFill.getLoc();
780 TensorReshapeOp newInit;
781 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
783 newInit = TensorReshapeOp::create(
784 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
785 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
786 reshapeOp.getStaticOutputShape());
788 newInit = TensorReshapeOp::create(
789 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
790 reshapeOp.getReassociation());
800struct FoldFillWithPad final :
public OpRewritePattern<tensor::PadOp> {
803 LogicalResult matchAndRewrite(tensor::PadOp padOp,
804 PatternRewriter &rewriter)
const override {
805 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
811 Value padValue = padOp.getConstantPaddingValue();
812 if (!padValue || fillOp.value() != padValue)
818 padOp,
"failed to reify tensor.pad op result shape");
821 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
822 padOp.getResultType().getElementType());
824 FillOp::create(rewriter, fillOp.getLoc(),
ValueRange{padValue},
827 if (
replacement.getType() != padOp.getResultType()) {
828 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
839struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
842 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
843 PatternRewriter &rewriter)
const override {
844 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
848 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
853 Value firstDest = insertOp.getDest();
854 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
855 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
860 bool disjoint =
false;
861 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
864 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
865 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
866 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
870 int64_t prevStart = prevOp.getStaticOffset(i);
871 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
872 prevOp.getStaticStride(i);
873 int64_t nextStart = insertOp.getStaticOffset(i);
874 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
875 insertOp.getStaticStride(i);
876 if (prevEnd < nextStart || nextEnd < prevStart) {
884 firstDest = prevOp.getDest();
895 Value padValue = srcPadOp.getConstantPaddingValue();
896 if (!padValue || dstFillOp.value() != padValue)
899 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
900 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
902 Location loc = insertOp.getLoc();
905 AffineExpr sym0, sym1;
911 SmallVector<OpFoldResult, 4> newOffsets;
912 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
914 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
917 RankedTensorType srcPadType = srcPadOp.getSourceType();
918 SmallVector<OpFoldResult, 4> newSizes;
919 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
920 if (srcPadType.isDynamicDim(i)) {
922 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
925 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
930 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
931 newSizes, insertOp.getMixedStrides());
937struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
939 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
941 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
942 PatternRewriter &rewriter)
const override {
945 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
950 Value extractedScalar = fillOp.getInputs()[0];
953 rewriter.
replaceOp(extractOp, extractedScalar);
961static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
962 linalg::PackOp packOp) {
963 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
967 if (
auto paddingValue = packOp.getPaddingValue())
971 Value packOpDest = packOp.getDest();
975 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
980struct FoldFillWithPack :
public OpRewritePattern<linalg::PackOp> {
982 FoldFillWithPack(MLIRContext *context)
983 : OpRewritePattern<linalg::PackOp>(context) {}
985 LogicalResult matchAndRewrite(linalg::PackOp packOp,
986 PatternRewriter &rewriter)
const override {
987 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
990 rewriter.
replaceOp(packOp, fillOp.value().result());
996struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
997 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
999 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
1000 PatternRewriter &rewriter)
const override {
1001 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
1004 copyOp.getOutputs());
1007 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
1009 fillOp.getOutputs());
1017struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1018 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1020 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1021 PatternRewriter &rewriter)
const override {
1022 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1024 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1025 transposeOp.getDpsInitOperand(0)->get());
1034struct FoldConcatsOfFill :
public OpRewritePattern<tensor::ConcatOp> {
1037 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1038 PatternRewriter &rewriter)
const override {
1039 auto concatOperands = concatOp.getInputs();
1040 if (concatOperands.empty()) {
1044 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1049 OpFoldResult firstFillVal =
1052 SmallVector<Value> allOuts;
1053 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1055 auto isDefinedByCompatibleFillOp = [&](Value v) ->
bool {
1056 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1061 OpFoldResult fillVal =
1063 if (fillVal != firstFillVal)
1066 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1069 if (!llvm::all_of(concatOperands.drop_front(),
1070 isDefinedByCompatibleFillOp)) {
1072 concatOp,
"not all operands are defined by a compatible fill op");
1075 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1076 concatOp.getDim(), allOuts);
1078 concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
/// Registers the canonicalization patterns for `FillOp` into `results`.
///
/// All patterns are constructed with `context`. The set added here is:
/// FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
/// FoldFillWithPack, FoldFillWithPad, FoldFillWithTensorReshape (instantiated
/// for both tensor::CollapseShapeOp and tensor::ExpandShapeOp),
/// FoldInsertPadIntoFill and FoldFillWithTranspose.
1085void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1086 MLIRContext *context) {
1087 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1088 FoldFillWithPack, FoldFillWithPad,
1089 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1090 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1091 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1104 for (
ValueRange container : {inputs, outputs}) {
1105 for (
Value v : container) {
1106 Type t = v.getType();
1107 blockArgTypes.push_back(
1109 blockArgLocs.push_back(v.getLoc());
1115 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1119void GenericOp::getAsmBlockArgumentNames(Region ®ion,
1121 for (Value v : getRegionInputArgs())
1123 for (Value v : getRegionOutputArgs())
1124 setNameFn(v,
"out");
1127void GenericOp::build(
1128 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1130 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1132 ArrayRef<NamedAttribute> attributes) {
1133 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1134 iteratorTypes, doc, libraryCall);
1135 result.addAttributes(attributes);
1138 inputs, outputs, bodyBuild);
1141void GenericOp::build(
1142 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1144 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1145 StringRef libraryCall,
1147 ArrayRef<NamedAttribute> attributes) {
1148 build(builder,
result, resultTensorTypes, inputs, outputs,
1152 [&](utils::IteratorType iter) -> mlir::Attribute {
1153 return IteratorTypeAttr::get(builder.getContext(), iter);
1156 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1157 bodyBuild, attributes);
1160void GenericOp::build(
1162 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1163 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1164 StringRef libraryCall,
1166 ArrayRef<NamedAttribute> attributes) {
1168 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1171void GenericOp::build(
1173 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1174 ArrayRef<utils::IteratorType> iteratorTypes,
1176 ArrayRef<NamedAttribute> attributes) {
1177 build(builder,
result, inputs, outputs, indexingMaps, iteratorTypes,
1179 "", bodyBuild, attributes);
1182void GenericOp::build(
1183 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1185 ArrayRef<utils::IteratorType> iteratorTypes,
1187 ArrayRef<NamedAttribute> attributes) {
1188 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1191 "", bodyBuild, attributes);
1194void GenericOp::print(OpAsmPrinter &p) {
1198 auto genericAttrNames = linalgTraitAttrNames();
1200 llvm::StringSet<> genericAttrNamesSet;
1201 genericAttrNamesSet.insert_range(genericAttrNames);
1202 SmallVector<NamedAttribute, 8> genericAttrs;
1203 for (
auto attr : (*this)->getAttrs()) {
1204 if (attr.getName() == getIteratorTypesAttrName()) {
1205 auto iteratorTypes =
1206 llvm::cast<ArrayAttr>(attr.getValue())
1207 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1212 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1213 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1214 return StringAttr::get(
getContext(), stringifyIteratorType(t));
1217 genericAttrs.emplace_back(
1218 getIteratorTypesAttrName(),
1219 ArrayAttr::get(
getContext(), iteratorTypeNames));
1220 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1221 genericAttrs.push_back(attr);
1224 if (!genericAttrs.empty()) {
1225 auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
1226 p << genericDictAttr;
1232 genericAttrNames.push_back(
"operandSegmentSizes");
1233 genericAttrNamesSet.insert(genericAttrNames.back());
1235 bool hasExtraAttrs =
false;
1236 for (NamedAttribute n : (*this)->getAttrs()) {
1237 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1240 if (hasExtraAttrs) {
1247 if (!getRegion().empty()) {
1256ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
1257 DictionaryAttr dictAttr;
1265 result.attributes.assign(dictAttr.getValue().begin(),
1266 dictAttr.getValue().end());
1272 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1273 result.attributes.get(getIteratorTypesAttrName(
result.name)));
1274 if (!iteratorTypes) {
1275 return parser.
emitError(attributeLocation)
1276 <<
"expected " << getIteratorTypesAttrName(
result.name)
1277 <<
" array attribute";
1280 SmallVector<Attribute> iteratorTypeAttrs;
1282 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1283 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1284 if (!maybeIteratorType.has_value())
1286 <<
"unexpected iterator_type (" << s <<
")";
1288 iteratorTypeAttrs.push_back(
1289 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
1291 result.attributes.set(getIteratorTypesAttrName(
result.name),
1295 SmallVector<Type, 1> inputTypes, outputTypes;
1305 std::unique_ptr<Region> region = std::make_unique<Region>();
1308 result.addRegion(std::move(region));
1314 SmallVector<Type, 1> outputTensorsTypes;
1317 result.addTypes(outputTensorsTypes);
1325 LinalgOp linalgOp) {
1326 for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1327 if (!llvm::isa<MemRefType>(operand.
getType()))
1329 effects.emplace_back(
1334 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1335 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1337 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1348void GenericOp::getEffects(
1349 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1358 if (!linalgOp.hasPureTensorSemantics())
1376template <
typename OpTy>
1377struct EraseIdentityLinalgOp :
public OpRewritePattern<OpTy> {
1378 using OpRewritePattern<OpTy>::OpRewritePattern;
1380 LogicalResult matchAndRewrite(OpTy linalgOp,
1381 PatternRewriter &rewriter)
const override {
1383 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1388 Block &body = linalgOp->getRegion(0).front();
1389 if (!llvm::hasSingleElement(body))
1391 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1396 if (linalgOp.hasPureBufferSemantics()) {
1397 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1398 linalgOp.getDpsInputOperand(0)->get() !=
1399 linalgOp.getDpsInitOperand(0)->get()) {
1401 linalgOp,
"expected single input and output to be the same value");
1404 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1405 if (!yieldArg || yieldArg.getOwner() != &body) {
1407 "cannot fold fill-like op");
1414 if (!linalgOp.hasPureTensorSemantics()) {
1416 linalgOp,
"mixed semantics is not supported yet");
1421 SmallVector<Value> returnedArgs;
1422 for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1423 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1424 if (!yieldArg || yieldArg.getOwner() != &body)
1426 unsigned argumentNumber = yieldArg.getArgNumber();
1427 Value returnedArg = linalgOp->getOperand(argumentNumber);
1428 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1431 Type returnType = returnedArg.
getType();
1432 if (returnType != resultType) {
1437 returnedArg = sparse_tensor::ConvertOp::create(
1438 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1440 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1443 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1444 resultType, returnedArg);
1447 returnedArgs.push_back(returnedArg);
1450 if (returnedArgs.size() != linalgOp->getNumResults())
1452 rewriter.
replaceOp(linalgOp, returnedArgs);
/// Registers the canonicalization patterns for `GenericOp` into `results`.
/// The only pattern added here is `EraseIdentityLinalgOp<GenericOp>`,
/// constructed with `context`.
1459void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1460 MLIRContext *context) {
1461 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1464LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1483 for (
Type outputType : outputTypes) {
1484 if (llvm::isa<RankedTensorType>(outputType))
1485 result.addTypes(outputType);
1489 if (parseAttrsFn && failed(parseAttrsFn(parser,
result.attributes)))
1498void MapOp::getAsmBlockArgumentNames(Region ®ion,
1500 for (Value v : getRegionInputArgs())
1502 for (Value v : getRegionOutputArgs())
1503 setNameFn(v,
"init");
/// Suggests a name for the op's result through `setNameFn`.
///
/// When the op has at least one result, the first result is given the name
/// "mapped". When there are no results, `setNameFn` is not called.
1506void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1507 if (!getResults().empty())
1508 setNameFn(getResults().front(),
"mapped");
1514 ArrayRef<NamedAttribute> attributes) {
1516 result.addAttributes(attributes);
1519 Type initType = init.
getType();
1520 if (llvm::isa<RankedTensorType>(initType))
1521 result.addTypes(initType);
1525 inputs, {init}, bodyBuild);
1532 bool initFirst =
false,
bool mapInit =
true) {
1536 b.setInsertionPointToStart(&block);
1537 for (
auto &operand : operands) {
1539 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1547 payloadOpOperands.push_back(block.
getArguments().back());
1548 for (
const auto &arg : block.
getArguments().drop_back())
1549 payloadOpOperands.push_back(arg);
1558 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1564ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
1565 std::optional<OperationName> payloadOpName;
1566 NamedAttrList payloadOpAttrs;
1569 if (
failed(operationName))
1573 payloadOpName = operationName.value();
1581 if (payloadOpName.has_value()) {
1582 if (!
result.operands.empty())
1584 payloadOpAttrs, ArrayRef(
result.operands),
false,
1589 SmallVector<OpAsmParser::Argument> regionArgs;
1594 Region *body =
result.addRegion();
1602 bool mapInit =
true) {
1604 if (initFirst && !mapInit)
1628 for (
const auto &[operand, bbArg] :
1630 if (bbArg != operand)
1634 for (
const auto &[operand, bbArg] :
1637 if (bbArg != operand)
1644 return yieldOp.getNumOperands() == 1 &&
1645 yieldOp.getOperand(0).getDefiningOp() &&
1646 yieldOp.getOperand(0).getDefiningOp() == &payload;
1651 std::string attrToElide;
1653 for (
const auto &attr : payloadOp->
getAttrs()) {
1655 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1656 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1657 attrToElide = attr.getName().str();
1658 elidedAttrs.push_back(attrToElide);
1666void MapOp::print(OpAsmPrinter &p) {
1667 Block *mapper = getBody();
1677 if (!useShortForm) {
1683 [&](
auto arg) { p.printRegionArgument(arg); });
1691LogicalResult MapOp::verify() {
1692 auto *bodyBlock = getBody();
1693 auto blockArgs = bodyBlock->getArguments();
1697 if (getInputs().size() + 1 != blockArgs.size())
1698 return emitOpError() <<
"expects number of operands to match the arity of "
1700 << getInputs().size() + 1 <<
" and "
1701 << blockArgs.size();
1704 for (
const auto &[bbArgType, inputArg] :
1705 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1706 auto inputElemType =
1707 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1708 if (bbArgType != inputElemType) {
1709 return emitOpError() <<
"expected element type of input " << inputElemType
1710 <<
" to match bbArg type " << bbArgType;
1715 auto outputShape = getInit().getType().getShape();
1716 for (Type inputArgType :
TypeRange{getInputs()}) {
1717 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1718 if (inputElemShape != outputShape) {
1719 return emitOpError() <<
"expected shape of input (" << inputElemShape
1720 <<
") to match shape of output (" << outputShape
/// Returns the iterator types of this op: a vector with one
/// `utils::IteratorType::parallel` entry per dimension of the init operand.
1728SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
// The number of iterators is the rank of the init operand's type.
1729 int64_t rank = getInit().getType().getRank();
1730 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1735 int64_t rank = getInit().getType().getRank();
1736 int64_t numIndexingMaps = getOperands().size();
1741void MapOp::getEffects(
1742 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1755void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1757 for (Value v : getRegionInputArgs())
1759 for (Value v : getRegionOutputArgs())
1760 setNameFn(v,
"init");
1763void ReduceOp::getAsmResultNames(
1765 if (!getResults().empty())
1766 setNameFn(getResults().front(),
"reduced");
1769void ReduceOp::build(
1771 ValueRange inits, ArrayRef<int64_t> dimensions,
1773 ArrayRef<NamedAttribute> attributes) {
1775 result.addAttributes(attributes);
1778 for (Value init : inits) {
1779 Type initType = init.
getType();
1780 if (llvm::isa<RankedTensorType>(initType))
1781 result.addTypes(initType);
1786 inputs, inits, bodyBuild);
1789SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1791 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1792 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1793 utils::IteratorType::parallel);
1794 for (int64_t reductionDim : getDimensions())
1795 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1796 return iteratorTypes;
1801 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1802 SmallVector<AffineMap> affineMaps(
1805 AffineMap resultMap =
1808 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1809 affineMaps.push_back(resultMap);
1810 return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
1813void ReduceOp::getEffects(
1814 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1825 StringRef attributeName) {
1833ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
1834 std::optional<OperationName> payloadOpName;
1835 NamedAttrList payloadOpAttrs;
1838 if (
failed(operationName))
1842 payloadOpName = operationName.value();
1848 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1853 if (payloadOpName.has_value()) {
1855 ArrayRef(
result.operands),
true);
1857 SmallVector<OpAsmParser::Argument> regionArgs;
1863 Region *body =
result.addRegion();
1873 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1876void ReduceOp::print(OpAsmPrinter &p) {
1877 Block *mapper = getBody();
1886 if (!useShortForm) {
1892 [&](
auto arg) { p.printRegionArgument(arg); });
1900LogicalResult ReduceOp::verify() {
1901 ArrayRef<int64_t> dimensionsRef = getDimensions();
1908 if (getInputs().size() !=
static_cast<size_t>(getNumDpsInputs()))
1910 <<
"expected equal number of inputs and outputs (required by "
1911 "SameVariadicOperandSize), got "
1912 << getNumDpsInputs() <<
" input(s) and " << getNumDpsInits()
1915 if (getInputs().empty())
1916 return emitOpError() <<
"expected at least one input";
1918 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1921 return emitOpError() <<
"expects all inputs to have the same shapes. "
1922 "Shape at input-index "
1924 <<
" is not equal to the shape at input-index 0.";
1927 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1930 return emitOpError() <<
"expects all outputs to have the same shapes. "
1931 "Shape at output-index "
1933 <<
" is not equal to the shape at output-index 0.";
1936 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1937 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1940 for (int64_t dimension : dimensionsRef) {
1941 if (dimension < 0 || dimension >= inputType.getRank()) {
1943 <<
"dimensions for reduction should be in the range [0, "
1944 << inputType.getRank() - 1 <<
"].";
1946 dimensionsToReduce.insert(dimension);
1949 auto inputDims = inputType.getShape();
1950 auto initDims = initType.getShape();
1953 SmallVector<int64_t> reducedInputDims;
1954 for (
const auto &en : llvm::enumerate(inputDims)) {
1955 if (!dimensionsToReduce.count(en.index()))
1956 reducedInputDims.push_back(en.value());
1959 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1960 return emitOpError() <<
"number of dimensions after reduction "
1961 << reducedInputDims.size()
1962 <<
" doesn't match the init rank "
1963 << initType.getRank();
1966 if (reducedInputDims != initDims)
1967 return emitOpError() <<
"init dimensions [" << initDims
1968 <<
"] doesn't match input dimensions after reduction ["
1969 << reducedInputDims <<
"]";
1971 Block *block = getBody();
1974 <<
"mismatching number of operands and block arguments";
1977 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1978 Type inputElementType =
1979 llvm::cast<ShapedType>(input.getType()).getElementType();
1980 if (inputElementType != bbArg.getType())
1982 <<
"input element type " << inputElementType
1983 <<
" does not match corresponding block argument type "
1988 for (
auto [output, bbArg] : llvm::zip(
1989 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1990 auto outputElementType =
1991 llvm::cast<ShapedType>(output.getType()).getElementType();
1992 if (outputElementType != bbArg.getType())
1994 <<
"output element type " << outputElementType
1995 <<
" does not match corresponding block argument type "
2011 linalg::YieldOp::create(
b, loc, args[0]);
2015void TransposeOp::build(::mlir::OpBuilder &builder,
2016 ::mlir::OperationState &
result, Value input, Value init,
2018 ArrayRef<NamedAttribute> attributes) {
2019 result.addOperands(input);
2020 result.addOperands(init);
2021 result.addAttribute(getPermutationAttrName(
result.name), permutation);
2022 result.addAttributes(attributes);
2025 Type initType = init.
getType();
2026 if (llvm::isa<RankedTensorType>(initType))
2027 result.addTypes(initType);
2033void TransposeOp::build(::mlir::OpBuilder &builder,
2034 ::mlir::OperationState &
result, Value input, Value init,
2035 ArrayRef<int64_t> permutation,
2036 ArrayRef<NamedAttribute> attributes) {
2041ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
2043 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2055void TransposeOp::getAsmResultNames(
2057 if (!getResults().empty())
2058 setNameFn(getResults().front(),
"transposed");
2061void TransposeOp::print(OpAsmPrinter &p) {
2067LogicalResult TransposeOp::verify() {
2068 ArrayRef<int64_t> permutationRef = getPermutation();
2073 auto inputType = getInput().getType();
2074 auto initType = getInit().getType();
2076 int64_t rank = inputType.getRank();
2082 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
2083 return emitOpError() <<
"size of permutation " << permutationRef.size()
2084 <<
" does not match the argument rank " << rank;
2086 auto inputDims = inputType.getShape();
2087 auto initDims = initType.getShape();
2089 for (int64_t i = 0; i < rank; ++i) {
2090 int64_t inputDim = inputDims[permutationRef[i]];
2091 int64_t initDim = initDims[i];
2093 if (inputDim != initDim) {
2094 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
2095 <<
" doesn't match dim(input, permutation[" << i
2096 <<
"]) = " << inputDim;
2103SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2104 int64_t rank = getInit().getType().getRank();
2105 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2108ArrayAttr TransposeOp::getIndexingMaps() {
2110 int64_t rank = getInit().getType().getRank();
2113 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
2117void TransposeOp::getEffects(
2118 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2127LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2128 SmallVectorImpl<OpFoldResult> &
result) {
2130 if (!isa<TensorType>(getInput().
getType()))
2134 if (getPermutation().empty()) {
2135 result.push_back(getInput());
2140 result.push_back(getInput());
2153 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2154 if (!defTransposeOp)
2159 foldedPerms.reserve(perms.size());
2161 foldedPerms.push_back(defPerms[perm]);
2164 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2177 if (!transposeOp.hasPureTensorSemantics())
2182 if (!splatValue.has_value())
2186 cast<RankedTensorType>(transposeOp.getResult()[0].getType());
2203 Value input = transposeOp.getInput();
2204 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2215 unsigned dimensionSize = dimensions.size();
2216 for (
unsigned i = 0; i < dimensionSize; ++i)
2217 resultDimensions.push_back(invertPerm[dimensions[i]]);
2220 Value broadcastInput = broadcastOp.getInput();
2221 Location loc = transposeOp.getLoc();
2224 auto broadcastInputTy =
2225 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2226 unsigned inputRank = broadcastInputTy.getRank();
2227 for (
unsigned i = 0; i < inputRank; ++i) {
2228 if (broadcastInputTy.isDynamicDim(i)) {
2229 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2232 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2233 broadcastInputTy.getDimSize(i)));
2238 Value transposeInit = tensor::EmptyOp::create(
2239 rewriter, transposeOp.getLoc(), transposeResultShapes,
2240 broadcastInputTy.getElementType());
2243 Value transposeResult =
2244 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2245 transposeInit, resultPerms)
2248 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2253void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2254 MLIRContext *context) {
2255 results.
add<FoldTransposeWithTranspose, FoldTransposeSplatConstant,
2256 SwapTransposeWithBroadcast>(context);
2263void BroadcastOp::build(::mlir::OpBuilder &builder,
2264 ::mlir::OperationState &
result, Value input, Value init,
2266 ArrayRef<NamedAttribute> attributes) {
2267 result.addOperands(input);
2268 result.addOperands(init);
2269 result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
2270 result.addAttributes(attributes);
2273 Type initType = init.
getType();
2274 if (llvm::isa<RankedTensorType>(initType))
2275 result.addTypes(initType);
2281void BroadcastOp::build(::mlir::OpBuilder &builder,
2282 ::mlir::OperationState &
result, Value input, Value init,
2283 ArrayRef<int64_t> dimensions,
2284 ArrayRef<NamedAttribute> attributes) {
2289ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
2291 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2303void BroadcastOp::getAsmResultNames(
2305 if (!getResults().empty())
2306 setNameFn(getResults().front(),
"broadcasted");
2309void BroadcastOp::print(OpAsmPrinter &p) {
2315LogicalResult BroadcastOp::verify() {
2316 ArrayRef<int64_t> dimensionsRef = getDimensions();
2318 auto inputType = getInput().getType();
2319 auto initType = getInit().getType();
2321 int64_t inputRank = inputType.getRank();
2322 int64_t initRank = initType.getRank();
2324 auto inputShape = inputType.getShape();
2325 auto initShape = initType.getShape();
2327 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2328 return emitOpError() <<
"input rank plus added dimensions does not "
2329 "match init rank. input rank: "
2331 <<
", dimensions size: " << dimensionsRef.size()
2332 <<
", init rank: " << initRank;
2334 for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2335 if (dim < 0 || dim >= initRank)
2337 <<
" is out of range. expected range: [0, "
2338 << initRank - 1 <<
"], got: " << dim;
2342 SmallVector<int64_t> dimMap;
2343 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2344 if (!llvm::is_contained(dimensionsRef, dim))
2345 dimMap.push_back(dim);
2348 for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2351 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2352 return emitOpError() <<
"input dim " << inputDimIdx
2353 <<
" should match init dim " << initDimIdx
2354 <<
". input: " << inputShape[inputDimIdx]
2355 <<
", init: " << initShape[initDimIdx];
2361SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2362 int64_t rank = getInit().getType().getRank();
2363 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2366ArrayAttr BroadcastOp::getIndexingMaps() {
2368 int64_t rank = getInit().getType().getRank();
2374void BroadcastOp::getEffects(
2375 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2390 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2391 if (!defBroadcastOp)
2396 Value init = broadcastOp.getInit();
2400 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2401 if (!llvm::is_contained(dimensions, dim))
2402 dimMap.push_back(dim);
2404 for (
auto dim : defDimensions)
2405 foldedDims.push_back(dimMap[dim]);
2407 llvm::sort(foldedDims);
2409 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2421 if (!broadcastOp.hasPureTensorSemantics())
2427 if (!splatValue.has_value())
2431 cast<RankedTensorType>(broadcastOp.getResult()[0].getType());
2432 if (!resultType.hasStaticShape())
2434 "result type has dynamic shape");
2443void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2444 MLIRContext *context) {
2445 results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts,
2446 FoldBroadcastSplatConstant>(context);
2453void linalg::YieldOp::print(OpAsmPrinter &p) {
2454 if (getNumOperands() > 0)
2455 p <<
' ' << getOperands();
2457 if (getNumOperands() > 0)
2458 p <<
" : " << getOperandTypes();
2461ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
2462 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2463 SmallVector<Type, 2> types;
2473static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2474 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2475 return op.emitOpError(
"expected number of yield values (")
2476 << op.getNumOperands()
2477 <<
") to match the number of inits / outs operands of the enclosing "
2478 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2480 for (
OpOperand &opOperand : op->getOpOperands()) {
2482 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2484 if (isa<MemRefType, RankedTensorType>(elementType))
2486 if (opOperand.get().getType() != elementType)
2487 return op.emitOpError(
"type of yield operand ")
2488 << (opOperand.getOperandNumber() + 1) <<
" ("
2489 << opOperand.get().getType() <<
") doesn't match "
2490 <<
"the element type of the enclosing linalg.generic op ("
2491 << elementType <<
")";
2496LogicalResult linalg::YieldOp::verify() {
2497 auto *parentOp = (*this)->getParentOp();
2498 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2499 return emitOpError(
"expected single non-empty parent region");
2501 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2504 return emitOpError(
"expected parent op with LinalgOp interface");
2511LogicalResult IndexOp::verify() {
2512 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2514 return emitOpError(
"expected parent op with LinalgOp interface");
2515 if (linalgOp.getNumLoops() <= getDim())
2517 << getDim() <<
") to be lower than the number of loops ("
2518 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2522OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2523 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2528 return OpFoldResult{};
2531 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2532 uint64_t dim = getDim();
2533 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
2534 if (loopBounds[dim] == 1)
2535 return IntegerAttr::get(IndexType::get(
getContext()), 0);
2537 return OpFoldResult{};
2542#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2544#define GET_OP_CLASSES
2545#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2547#define GET_OP_CLASSES
2548#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2549#define GET_OP_CLASSES
2550#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2567 for (
unsigned i = 0; i < num; ++i)
2574 auto rangeA = llvm::make_range(a.begin(), a.end());
2575 auto rangeB = llvm::make_range(
b.begin(),
b.end());
2576 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2577 return llvm::to_vector<4>(concatRanges);
2581 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2583 for (
auto size :
memref.getShape())
2590 if (
auto as =
memref.getMemorySpace()) {
2591 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2592 ss <<
"as" << attr.getInt();
2598 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2601 vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2614 assert(isa<LinalgOp>(op));
2616 std::string fun =
"";
2618 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2619 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2620 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2621 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2625 llvm::replace(name,
'.',
'_');
2626 llvm::raw_string_ostream ss(name);
2630 return std::string();
2645 LogicalResult matchAndRewrite(LinalgOp op,
2647 for (
OpOperand &opOperand : op->getOpOperands()) {
2651 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2654 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2665struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2666 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2668 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2669 PatternRewriter &rewriter)
const override {
2673 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2680 if (castOp->getBlock() != linalgOp->getBlock())
2683 OpBuilder::InsertionGuard guard(rewriter);
2686 Location loc = linalgOp.getLoc();
2687 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2690 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2696 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2698 tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
2699 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2700 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2701 linalgOp.getDpsInits().end());
2702 outputOperands[resultNumber] = newOperand;
2703 newOperands.append(outputOperands.begin(), outputOperands.end());
2705 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2706 linalgOp->result_type_end());
2707 resultTypes[resultNumber] = resultType;
2708 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2711 Value castBack = tensor::CastOp::create(
2715 results[resultNumber] = castBack;
2724static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2725 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2726 for (OpOperand &opOperand : operands) {
2727 if (linalgOp.isScalar(&opOperand))
2729 Value src = opOperand.get();
2730 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2731 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2737 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2739 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2740 Value castSource = castOp.getSource();
2741 auto castSourceType =
2742 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2743 if (castSourceType && castSourceType.hasStaticShape())
2744 sourceShape = castSourceType.getShape();
2750 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2751 if (sourceType.isDynamicDim(i))
2753 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2754 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2764static void createNewOperandWithStaticSizes(
2765 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2766 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2767 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2768 bool &changeNeeded) {
2769 Value src = opOperand->
get();
2770 newOperands.push_back(src);
2771 if (linalgOp.isScalar(opOperand))
2773 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2774 Type resultType = sourceType;
2775 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2776 resultTypes.push_back(resultType);
2779 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2780 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2781 SmallVector<int64_t> newShape;
2784 bool newOperandNeeded =
false;
2785 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2786 int64_t dimShape = sourceShape[i];
2787 AffineExpr dimExpr = sourceMap.
getResult(i);
2788 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2789 newShape.push_back(dimShape);
2795 newShape.push_back(affineExprToSize[dimExpr]);
2796 newOperandNeeded =
true;
2798 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2799 sourceType.getEncoding());
2800 if (newOperandNeeded) {
2801 changeNeeded =
true;
2804 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2806 newOperands[index] = newOperand;
2808 if (linalgOp.isDpsInit(opOperand))
2809 resultTypes.push_back(resultType);
2815struct InferStaticShapeOfOperands :
public OpInterfaceRewritePattern<LinalgOp> {
2816 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2818 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2819 PatternRewriter &rewriter)
const override {
2820 if (!linalgOp.hasPureTensorSemantics())
2824 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2825 return !map.isProjectedPermutation();
2830 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2831 Location loc = linalgOp.getLoc();
2835 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2837 SmallVector<Value> newOperands;
2838 SmallVector<Type> resultTypes;
2842 bool changeNeeded =
false;
2843 newOperands.reserve(linalgOp->getNumOperands());
2844 resultTypes.reserve(linalgOp.getNumDpsInits());
2847 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2848 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2849 affineExprToSize, linalgOp, newOperands,
2850 resultTypes, changeNeeded);
2859 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2860 SmallVector<Value> replacements;
2862 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2863 Value newResult = std::get<1>(it);
2864 Value oldResult = std::get<0>(it);
2865 Type newType = newResult.
getType();
2866 Type oldType = oldResult.
getType();
2867 replacements.push_back(
2868 (newType != oldType)
2869 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2872 rewriter.
replaceOp(linalgOp, replacements);
2886LogicalResult SoftmaxOp::verify() {
2887 ShapedType inputType = getInputOperandType();
2888 ShapedType outputType = getOutputOperandType();
2890 ArrayRef<int64_t> inputShape = inputType.getShape();
2891 ArrayRef<int64_t> outputShape = outputType.getShape();
2895 int64_t inputRank = getInputOperandRank();
2896 int64_t dimension = getDimension();
2897 if ((dimension < 0) || (dimension >= inputRank))
2898 return emitOpError(
"incorrect dimension specified");
2903SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2904 int64_t operandRank = getInputOperandRank();
2905 SmallVector<Range> loopBounds(operandRank);
2906 Location loc = getLoc();
2909 Value source = getInput();
2910 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2911 loopBounds[dim].offset = zero;
2912 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2913 loopBounds[dim].stride = one;
2918SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2919 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2920 utils::IteratorType::parallel);
2921 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2922 return iteratorTypes;
2925FailureOr<TilingResult>
2926SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2927 ArrayRef<OpFoldResult> offsets,
2928 ArrayRef<OpFoldResult> sizes) {
2929 int64_t rank = getInputOperandRank();
2931 SmallVector<OpFoldResult> strides(rank, oneAttr);
2932 SmallVector<Value> tiledOperands;
2933 Operation *inputSlice =
2934 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2936 return emitOpError(
"failed to compute input slice");
2938 tiledOperands.emplace_back(inputSlice->
getResult(0));
2939 Operation *outputSlice =
2940 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2942 return emitOpError(
"failed to compute output slice");
2944 tiledOperands.emplace_back(outputSlice->
getResult(0));
2946 SmallVector<Type, 4> resultTypes;
2947 if (hasPureTensorSemantics())
2948 resultTypes.push_back(tiledOperands[1].
getType());
2949 Operation *tiledOp =
2950 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2952 return TilingResult{
2955 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2958LogicalResult SoftmaxOp::getResultTilePosition(
2959 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2960 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2961 SmallVector<OpFoldResult> &resultSizes) {
2962 if (resultNumber == 0) {
2963 resultOffsets.assign(offsets.begin(), offsets.end());
2964 resultSizes.assign(sizes.begin(), sizes.end());
2971LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2976SoftmaxOp::reifyResultShapes(OpBuilder &
b,
2978 SmallVector<OpFoldResult> shapes;
2979 Location loc = getOperation()->getLoc();
2980 IRRewriter rewriter(
b);
2981 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2982 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2983 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2984 if (!outputShapedType.isDynamicDim(dim)) {
2986 shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
2993 reifiedReturnShapes.emplace_back(std::move(shapes));
2997void SoftmaxOp::getEffects(
2998 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3000 for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
3001 if (!llvm::isa<MemRefType>(operand.
getType()))
3004 &getOperation()->getOpOperand(index), 0,
3009 for (OpOperand &operand : getDpsInitsMutable()) {
3010 if (!llvm::isa<MemRefType>(operand.get().
getType()))
3041static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
3043 int64_t dim,
bool allParallel =
false) {
3045 utils::IteratorType::parallel);
3047 iteratorTypes[dim] = utils::IteratorType::reduction;
3051 for (
int i = 0; i < inputRank; i++) {
3058 return std::make_tuple(iteratorTypes, indexingMaps);
3063template <
typename T>
3066 auto inputType = cast<ShapedType>(input.
getType());
3068 int64_t inputRank = inputShape.size();
3069 auto [iteratorTypes, indexingMaps] =
3071 assert(indexingMaps.size() == 2 &&
3072 "We should have two maps: 1 for the input, 1 for the output");
3073 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3075 auto genericOp = linalg::GenericOp::create(
3076 builder, loc, output.
getType(), input, output, indexingMaps,
3078 Value result = T::create(b, loc, args[0], args[1]);
3079 linalg::YieldOp::create(b, loc, result);
3081 return genericOp.getResult(0);
3089 auto inputType = cast<ShapedType>(input.
getType());
3091 int64_t inputRank = inputShape.size();
3093 builder, inputRank, dim,
true);
3094 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
3095 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3097 indexingMaps.push_back(indexingMaps[0]);
3098 auto genericOp = linalg::GenericOp::create(
3100 indexingMaps, iteratorTypes,
3102 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3103 Value result = math::ExpOp::create(b, loc, diff);
3104 linalg::YieldOp::create(b, loc, result);
3106 return genericOp.getResult(0);
3116 auto inputType = cast<ShapedType>(numerator.
getType());
3118 int64_t inputRank = inputShape.size();
3120 builder, inputRank, dim,
true);
3121 assert(indexingMaps.size() == 2 &&
3122 "We should have one map for each input (2)");
3123 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
3125 indexingMaps.push_back(indexingMaps[0]);
3126 auto genericOp = linalg::GenericOp::create(
3128 output, indexingMaps, iteratorTypes,
3130 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3131 linalg::YieldOp::create(b, loc, result);
3133 return genericOp.getResult(0);
3155FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
3156 OpBuilder::InsertionGuard guard(
b);
3157 b.setInsertionPoint(*
this);
3158 Location loc = getLoc();
3159 Value input = getInput();
3160 ShapedType inputType = getInputOperandType();
3161 Type elementType = inputType.getElementType();
3162 int64_t reductionDim = getDimension();
3164 Value output = getOutput();
3165 dims.erase(dims.begin() + reductionDim);
3167 Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
3169 elementType,
b, loc,
3171 Value neutralForMaxFInit =
3172 linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
3184 linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
3190 buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
3191 return SmallVector<Value>{
result};
3198LogicalResult WinogradFilterTransformOp::verify() {
3199 auto filterType = cast<ShapedType>(getFilter().
getType());
3200 ArrayRef<int64_t> filterShape = filterType.getShape();
3201 int64_t filterH = filterShape[getFilterHDim()];
3202 int64_t filterW = filterShape[getFilterWDim()];
3203 WinogradConv2DFmr fmr = getFmr();
3207 if (filterH != r && filterH != 1)
3208 return emitOpError(
"expect filter height either equals to r or 1");
3209 if (filterW != r && filterW != 1)
3210 return emitOpError(
"expect filter width either equals to r or 1");
3211 if (filterH == 1 && filterW == 1)
3212 return emitOpError(
"expect either filter height or width equals to r");
3214 SmallVector<int64_t> expectedOutputShape;
3215 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3216 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3217 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3218 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3220 auto outputType = cast<ShapedType>(getOutput().
getType());
3221 ArrayRef<int64_t> outputShape = outputType.getShape();
3223 return emitOpError(
"the output shape is not expected");
3229WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3230 Location loc = getLoc();
3233 Value filter = getFilter();
3234 int64_t filterRank = getFilterOperandRank();
3235 SmallVector<Range> loopBounds(filterRank);
3236 for (
unsigned dim = 0; dim < filterRank; ++dim) {
3237 loopBounds[dim].offset = zeroAttr;
3238 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3239 loopBounds[dim].stride = oneAttr;
3244SmallVector<utils::IteratorType>
3245WinogradFilterTransformOp::getLoopIteratorTypes() {
3246 int64_t filterRank = getFilterOperandRank();
3247 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3248 utils::IteratorType::parallel);
3249 return iteratorTypes;
3252LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3253 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3254 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3255 SmallVector<OpFoldResult> &resultSizes) {
3257 ShapedType filterType = getFilterOperandType();
3258 ArrayRef<int64_t> filterShape = filterType.getShape();
3259 int64_t filterH = filterShape[getFilterHDim()];
3260 int64_t filterW = filterShape[getFilterWDim()];
3261 WinogradConv2DFmr fmr = getFmr();
3264 int64_t alpha = m + r - 1;
3265 int64_t alphaH = filterH != 1 ? alpha : 1;
3266 int64_t alphaW = filterW != 1 ? alpha : 1;
3270 resultOffsets.append(
3271 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3273 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3284FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3285 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3286 ArrayRef<OpFoldResult> sizes) {
3289 ShapedType filterType = getFilterOperandType();
3290 ArrayRef<int64_t> filterShape = filterType.getShape();
3291 int64_t filterH = filterShape[getFilterHDim()];
3292 int64_t filterW = filterShape[getFilterWDim()];
3295 SmallVector<Value> tiledOperands;
3296 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3298 sliceOffsets.append(
3299 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3300 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3301 sizes[getFilterCDim()]});
3302 int64_t filterRank = getFilterOperandRank();
3303 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3304 Location loc = getLoc();
3305 auto filterSlice = tensor::ExtractSliceOp::create(
3306 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3307 tiledOperands.emplace_back(filterSlice);
3309 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3314 int64_t outputRank = getOutputOperandRank();
3315 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3316 auto outputSlice = tensor::ExtractSliceOp::create(
3317 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3318 tiledOperands.emplace_back(outputSlice);
3320 SmallVector<Type> resultTypes;
3321 resultTypes.push_back(tiledOperands[1].
getType());
3322 Operation *tiledOp =
3323 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3325 return TilingResult{
3328 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3335LogicalResult WinogradInputTransformOp::verify() {
3336 auto inputType = cast<ShapedType>(getInput().
getType());
3337 ArrayRef<int64_t> inputShape = inputType.getShape();
3338 int64_t inputH = inputShape[getInputHDim()];
3339 int64_t inputW = inputShape[getInputWDim()];
3340 WinogradConv2DFmr fmr = getFmr();
3343 int64_t tileSize = m + r - 1;
3345 auto outputType = cast<ShapedType>(getOutput().
getType());
3346 ArrayRef<int64_t> outputShape = outputType.getShape();
3347 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3348 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3350 SmallVector<int64_t> expectedOutputShape(6, inputH);
3351 if (ShapedType::isDynamic(inputH)) {
3352 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3353 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3355 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3356 expectedOutputShape[getOutputTileHDim()] =
3357 leftTransform ? (inputH - (r - 1)) / m : inputH;
3359 if (ShapedType::isDynamic(inputW)) {
3360 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3361 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3363 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3364 expectedOutputShape[getOutputTileWDim()] =
3365 rightTransform ? (inputW - (r - 1)) / m : inputW;
3367 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3368 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3371 return emitOpError(
"the output shape is not expected");
// Builds the iteration domain: one Range per dimension of the output operand,
// with offset `zeroAttr`, stride `oneAttr` and the size of that output
// dimension as returned by getDimValue.
// NOTE(review): `zeroAttr`/`oneAttr` are defined on lines not visible here;
// presumably index attributes 0 and 1 -- confirm.
3377WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3378 Location loc = getLoc();
3381 Value output = getOutput();
3382 int64_t outputRank = getOutputOperandRank();
3383 SmallVector<Range> loopBounds(outputRank);
3384 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3385 loopBounds[dim].offset = zeroAttr;
// Size of this output dimension.
3387 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3388 loopBounds[dim].stride = oneAttr;
// Returns one iterator type per output dimension; every one of them is
// `parallel`.
3393SmallVector<utils::IteratorType>
3394WinogradInputTransformOp::getLoopIteratorTypes() {
3395 int64_t outputRank = getOutputOperandRank();
3396 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3397 utils::IteratorType::parallel);
3398 return iteratorTypes;
// Computes the offsets and sizes of the result tile for the given
// iteration-space `offsets`/`sizes` and appends them to `resultOffsets` and
// `resultSizes`. The two alpha dimensions start at `zeroAttr` and take the
// alpha sizes; the tile H/W, N and C entries are forwarded unchanged from the
// iteration space.
// NOTE(review): `m`, `r`, `zeroAttr`, `alphaHAttr` and `alphaWAttr` are
// defined on lines not visible here; the two attrs are presumably built from
// alphaH/alphaW -- confirm.
3401LogicalResult WinogradInputTransformOp::getResultTilePosition(
3402 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3403 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3404 SmallVector<OpFoldResult> &resultSizes) {
3406 ShapedType outputType = getOutputOperandType();
3407 ArrayRef<int64_t> outputShape = outputType.getShape();
3408 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3409 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3411 WinogradConv2DFmr fmr = getFmr();
// alpha = m + r - 1; an output alpha extent of 1 is kept as 1.
3414 int64_t alpha = m + r - 1;
3415 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3416 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3421 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3422 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3423 offsets[getOutputCDim()]});
3424 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3425 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3426 sizes[getOutputCDim()]});
// Produces the tiled version of this op for the given iteration-space tile:
// extracts a slice of the input and a slice of the output, clones the op onto
// those two slices and returns the clone together with both slice ops.
// NOTE(review): the affine-map definitions, the calls that consume them, and
// the computation of resultOffsets/resultSizes are on lines not visible here.
3437FailureOr<TilingResult>
3438WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3439 ArrayRef<OpFoldResult> offsets,
3440 ArrayRef<OpFoldResult> sizes) {
3442 WinogradConv2DFmr fmr = getFmr();
3446 ShapedType outputType = getOutputOperandType();
3447 ArrayRef<int64_t> outputShape = outputType.getShape();
3448 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3449 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3451 Location loc = getLoc();
3453 auto identityAffineMap =
3455 auto offsetAffineMap =
// Tile offsets go through offsetAffineMap when the matching alpha extent is
// not 1, and through the identity map otherwise.
3458 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3459 offsets[getOutputTileHDim()]);
3461 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3462 offsets[getOutputTileWDim()]);
// Tile sizes are mapped through sizeAffineMap.
3466 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3468 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3470 SmallVector<Value> tiledOperands;
3471 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
// Input slice offsets, in the order (N, H, W, C).
3473 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3474 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3475 sliceOffsets.append(
3476 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
// With an alpha extent of 1 the input slice has size 1 along that axis.
3477 OpFoldResult sizeH =
3478 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3479 OpFoldResult sizeW =
3480 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3482 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
// Unit strides for the input slice.
3483 int64_t inputRank = getInputOperandRank();
3484 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3485 auto inputSlice = tensor::ExtractSliceOp::create(
3486 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3487 tiledOperands.emplace_back(inputSlice);
3489 SmallVector<OpFoldResult> resultOffsets, resultSizes;
// Output slice, also with unit strides.
3494 int64_t outputRank = getOutputOperandRank();
3495 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3496 auto outputSlice = tensor::ExtractSliceOp::create(
3497 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3498 tiledOperands.emplace_back(outputSlice);
// The tiled op's result type is the type of the output slice.
3500 SmallVector<Type> resultTypes;
3501 resultTypes.push_back(tiledOperands[1].
getType());
3502 Operation *tiledOp =
3503 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3505 return TilingResult{
3508 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
// Verifier for the Winograd output transform: checks that static alpha extents
// of the value operand match the tile size implied by F(m, r), derives the
// expected output shape, and reports "the output shape is not expected" on
// mismatch.
// NOTE(review): the definitions of `m`/`r` and the comparison guarding the
// final emitOpError are on lines not visible in this excerpt.
3515LogicalResult WinogradOutputTransformOp::verify() {
3516 auto valueType = cast<ShapedType>(getValue().
getType());
3517 ArrayRef<int64_t> valueShape = valueType.getShape();
3518 int64_t valueH = valueShape[getValueAlphaHDim()];
3519 int64_t valueW = valueShape[getValueAlphaWDim()];
3520 int64_t valueTileH = valueShape[getValueTileHDim()];
3521 int64_t valueTileW = valueShape[getValueTileWDim()];
3522 WinogradConv2DFmr fmr = getFmr();
// An alpha extent of 1 means no transform along that axis.
3525 bool leftTransform = valueH != 1;
3526 bool rightTransform = valueW != 1;
// The fill value is a placeholder; the H/W/N/F entries are assigned below.
3528 int64_t outputRank = getOutputOperandRank();
3529 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
// Height: dynamic if either the alpha or the tile extent is dynamic.
3530 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3531 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
// Static case: the alpha extent must be m + r - 1 (or 1 without a transform);
// the output height is then m (or 1) times the tile count.
3533 if (valueH != (leftTransform ? m + r - 1 : 1))
3534 return emitOpError(
"expect input height equals to input tile size");
3535 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
// Width: same rules as for the height.
3537 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3538 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3540 if (valueW != (rightTransform ? m + r - 1 : 1))
3541 return emitOpError(
"expect input width equals to input tile size");
3542 expectedOutputShape[getOutputWDim()] =
3543 (rightTransform ? m : 1) * valueTileW;
// Batch and filter extents are copied straight from the value operand.
3545 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3546 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3548 auto outputType = cast<ShapedType>(getOutput().
getType());
3549 ArrayRef<int64_t> outputShape = outputType.getShape();
3551 return emitOpError(
"the output shape is not expected");
// Builds the iteration domain: one Range per dimension of the value operand
// (not the output), with offset `zeroAttr`, stride `oneAttr` and the size of
// that value dimension as returned by getDimValue.
// NOTE(review): `zeroAttr`/`oneAttr` are defined on lines not visible here;
// presumably index attributes 0 and 1 -- confirm.
3557WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3558 Location loc = getLoc();
3561 Value value = getValue();
3562 int64_t valueRank = getValueOperandRank();
3563 SmallVector<Range> loopBounds(valueRank);
3564 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3565 loopBounds[dim].offset = zeroAttr;
// Size of this value dimension.
3567 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3568 loopBounds[dim].stride = oneAttr;
// Returns one iterator type per dimension of the value operand; every one of
// them is `parallel`.
3573SmallVector<utils::IteratorType>
3574WinogradOutputTransformOp::getLoopIteratorTypes() {
3575 int64_t valueRank = getValueOperandRank();
3576 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3577 utils::IteratorType::parallel);
3578 return iteratorTypes;
// Computes the offsets and sizes of the output tile for the given
// iteration-space `offsets`/`sizes` and appends them, in the order
// (N, H, W, F), to `resultOffsets` and `resultSizes`.
// NOTE(review): the affine-map definitions and the calls that produce the
// mapped* values are on lines not visible in this excerpt.
3581LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3582 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3583 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3584 SmallVector<OpFoldResult> &resultSizes) {
3585 WinogradConv2DFmr fmr = getFmr();
3589 Location loc = getLoc();
3591 auto identityAffineMap =
3596 ShapedType valueType = getValueOperandType();
3597 ArrayRef<int64_t> valueShape = valueType.getShape();
// NOTE(review): indices 0 and 1 are used directly here, while
// getTiledImplementation reads the same extents through
// getValueAlphaHDim()/getValueAlphaWDim(); presumably equivalent -- confirm.
3598 int64_t valueH = valueShape[0];
3599 int64_t valueW = valueShape[1];
// Tile offsets go through `affineMap` when the alpha extent is not 1, and
// through the identity map otherwise.
3601 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3602 offsets[getValueTileHDim()]);
3604 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3605 offsets[getValueTileWDim()]);
// Tile sizes always go through `affineMap`.
3607 builder, loc, affineMap, sizes[getValueTileHDim()]);
3609 builder, loc, affineMap, sizes[getValueTileWDim()]);
3612 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3613 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
// With an alpha extent of 1 the mapped size is replaced by `oneAttr`.
3614 OpFoldResult sizeH =
3615 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3616 OpFoldResult sizeW =
3617 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3619 resultOffsets.append(
3620 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3622 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
// Produces the tiled version of this op for the given iteration-space tile:
// extracts a slice of the value operand and a slice of the output, clones the
// op onto those two slices and returns the clone together with both slice ops.
// NOTE(review): `zeroAttr`, `oneAttr`, `alphaHAttr`, `alphaWAttr` and the
// computation of resultOffsets/resultSizes are on lines not visible here.
3632FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3633 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3634 ArrayRef<OpFoldResult> sizes) {
3637 Location loc = getLoc();
3638 SmallVector<Value> tiledOperands;
3639 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3641 ShapedType valueType = getValueOperandType();
3642 ArrayRef<int64_t> valueShape = valueType.getShape();
3643 int64_t alphaH = valueShape[getValueAlphaHDim()];
3644 int64_t alphaW = valueShape[getValueAlphaWDim()];
// Value slice: the two alpha dimensions start at `zeroAttr` with the alpha
// sizes; the remaining entries are forwarded from the iteration space.
3648 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3649 offsets[getValueTileWDim()], offsets[getValueNDim()],
3650 offsets[getValueFDim()]});
3651 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3652 sizes[getValueTileWDim()], sizes[getValueNDim()],
3653 sizes[getValueFDim()]});
// Unit strides for the value slice.
3654 int64_t valueRank = getValueOperandRank();
3655 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3656 auto valueSlice = tensor::ExtractSliceOp::create(
3657 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3658 tiledOperands.emplace_back(valueSlice);
3660 SmallVector<OpFoldResult> resultOffsets, resultSizes;
// Output slice, also with unit strides.
3665 int64_t outputRank = getOutputOperandRank();
3666 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3667 auto outputSlice = tensor::ExtractSliceOp::create(
3668 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3669 tiledOperands.emplace_back(outputSlice);
// The tiled op's result type is the type of the output slice.
3671 SmallVector<Type> resultTypes;
3672 resultTypes.push_back(tiledOperands[1].
getType());
3673 Operation *tiledOp =
3674 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3676 return TilingResult{
3679 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3693 llvm::set_union(explicitSet, defaultSet);
3694 return explicitSet == defaultSet;
3714 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3716 auto opIndexingMap = opIndexingMaps[opIndex];
3717 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3720 return matmulOp->emitOpError()
3721 <<
"Unexpected dim expression in map result.";
3724 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3725 return matmulOp->emitOpError()
3726 <<
"Invalid broadcast requested, should be (d2).";
3735template <
typename OpTy>
3738 AffineMap defaultIndexingMap,
bool isLHS) {
3739 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3740 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3741 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3744 return batchVariantMatmulOp->emitOpError()
3745 <<
"Unexpected result dim expression (outside the set of default "
3750 return batchVariantMatmulOp->emitOpError()
3751 <<
"no. of result dim expressions exceeds 3.";
3753 auto hasValidBatchDim = [](
AffineMap map) {
3760 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3761 return batchVariantMatmulOp->emitOpError()
3762 <<
"Invalid broadcast requested.";
3763 }
else if (!hasValidBatchDim(opIndexingMap)) {
3764 return batchVariantMatmulOp->emitOpError()
3765 <<
"Invalid batch dimension expression.";
3773template <
typename OpTy>
3776 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3777 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3778 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3779 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3782 return batchVariantMatmulOp->emitOpError()
3783 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3786 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3788 return batchVariantMatmulOp->emitOpError()
3789 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3793 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3794 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3795 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3796 outputMap.getResult(1).isFunctionOfDim(1) &&
3797 outputMap.getResult(2).isFunctionOfDim(2)
3798 : outputMap.getResult(0).isFunctionOfDim(1) &&
3799 outputMap.getResult(1).isFunctionOfDim(2);
3802 if (!areValidOutputResultDim(opIndexingMap)) {
3803 return batchVariantMatmulOp->emitOpError()
3804 <<
"Invalid output map result dimension.";
3813template <
typename OpTy>
3818 batchVariantMatmulOp.getIndexingMapsArray();
3820 batchVariantMatmulOp.getDefaultIndexingMaps(
3821 batchVariantMatmulOp->getContext());
3823 if (opIndexingMaps.size() != 3)
3824 return batchVariantMatmulOp->emitOpError()
3825 <<
"Indexing_map attribute must have 3 affine maps.";
3827 auto opIndexingMap = opIndexingMaps[opIndex];
3828 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3836 defaultIndexingMap, opIndex == 0)))
3846 if (m == 2 && r == 3)
3847 return WinogradConv2DFmr::F_2_3;
3848 if (m == 4 && r == 3)
3849 return WinogradConv2DFmr::F_4_3;
3850 if (m == 2 && r == 5)
3851 return WinogradConv2DFmr::F_2_5;
3852 return std::nullopt;
3857 case WinogradConv2DFmr::F_2_3:
3859 case WinogradConv2DFmr::F_4_3:
3861 case WinogradConv2DFmr::F_2_5:
3864 llvm_unreachable(
"Unkown WinogradConv2DFmr");
3871static FailureOr<SmallVector<SmallVector<int64_t>>>
3874 for (
auto map : maps) {
3875 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3879 for (
auto result : attr.getAffineMap().getResults()) {
3880 auto dim = dyn_cast<AffineDimExpr>(
result);
3883 pos.push_back(dim.getPosition());
3885 positions.push_back(pos);
3898 return indexingMaps;
// Tests whether `attr` describes the default matmul indexing maps: an array
// of exactly three maps whose result positions are (0, 2), (2, 1) and (0, 1).
// NOTE(review): the null check on `maps` and the computation of `positions`
// are on lines not visible in this excerpt.
3901bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3902 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3905 if (maps.size() != 3)
// Expected dim positions per operand, in the order LHS, RHS, output.
3910 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3911 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3912 (*positions)[2] == SmallVector<int64_t>{0, 1};
// Returns the three loop iterator types of a matmul: two parallel loops
// followed by one reduction loop.
3915SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3916 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3917 utils::IteratorType::parallel,
3918 utils::IteratorType::reduction};
3921unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3923std::string MatmulOp::getLibraryCallName() {
3927bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
// Returns true when the op's indexing maps differ from `defaultMaps`.
// NOTE(review): the initializer of `defaultMaps` is on a line not visible
// here; presumably getDefaultIndexingMaps(...) -- confirm.
3931bool MatmulOp::hasUserDefinedMaps() {
3932 SmallVector<AffineMap, 3> defaultMaps =
3934 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3935 return defaultMaps != explicitMaps;
// Populates the body block of a matmul: multiplies two values, adds the
// product to block argument 2 and yields the sum. Expects exactly three block
// arguments.
// NOTE(review): the argument-count check, the definitions of value1/value2
// (presumably block arguments 0 and 1 after casting) and the early returns
// are on lines not visible in this excerpt.
3940void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
3941 ArrayRef<NamedAttribute> attrs,
3944 emitError() <<
"MatmulOp regionBuilder expects 3 args, got "
3949 "MatmulOp regionBuilder expects 3 args");
3950 RegionBuilderHelper helper(
b, block);
3951 SmallVector<Value> yields;
// Signed cast is the default; an attribute named "cast" holding a TypeFnAttr
// is looked up (the override assignment itself is not visible here).
3953 TypeFn castVal = TypeFn::cast_signed;
3954 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3955 return attr.
getName() ==
"cast";
3957 if (castIter != attrs.end()) {
3958 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
// Product of the two inputs.
3966 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2,
emitError);
// A null value here indicates that building one of the operations failed.
3967 if (!value1 || !value2 || !value3)
// Accumulate into the output block argument.
3969 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
3973 yields.push_back(value4);
3974 helper.yieldOutputs(yields);
3984bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3985 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3986 AffineExpr expr = bcastMap.
getResult(0);
3996 ArrayAttr arrayAttr;
4000 if (llvm::any_of(arrayAttr,
4001 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
4003 <<
"element of indexing_maps array is not an affine_map";
4010 if (failed(indexingMapsAttr))
4013 if (*indexingMapsAttr ==
nullptr) {
4014 auto indexingMapAttrs = llvm::map_to_vector(
4015 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4020 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4022 MatmulOp::getRegionBuilder());
// Custom printer: emits " indexing_maps = [...]" only when the op's maps
// differ from the default maps, and lists the attributes to be elided.
// NOTE(review): the call that consumes `elidedAttrs` is on lines not visible
// in this excerpt.
4025void MatmulOp::print(OpAsmPrinter &p) {
// Default maps wrapped as attributes, for comparison with the op's own maps.
4026 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4027 MatmulOp::getDefaultIndexingMaps(
getContext()),
4028 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4029 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4030 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
// "indexing_maps" is elided here because it was handled explicitly above.
4032 std::array<StringRef, 3> elidedAttrs = {
4033 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4039LogicalResult MatmulOp::verify() {
4041 if (!hasUserDefinedMaps())
4044 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
4051LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4055void MatmulOp::getEffects(
4056 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4058 if (hasPureTensorSemantics())
4067SmallVector<AffineMap>
4068MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4069 AffineExpr d0, d1, d2;
4075 return {mapLHS, mapRHS, mapOut};
4079 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4082 if (maps.size() != 3)
4085 if (failed(positions))
4097 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4105 build(builder, state, inputs, outputs, attributes);
4106 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4107 assert(res &&
"builder didn't return the right type");
4117 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4126 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4127 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4128 assert(res &&
"builder didn't return the right type");
4138 result.addAttribute(
"cast", cast);
4140 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4149 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4150 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4151 assert(res &&
"builder didn't return the right type");
4156 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4158 op->
getAttr(
"indexing_maps"));
4162MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4169 return {mapLHS, mapRHS, mapOut};
4173 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4176 if (maps.size() != 3)
4179 if (failed(positions))
4191 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4199 build(builder, state, inputs, outputs, attributes);
4200 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4201 assert(res &&
"builder didn't return the right type");
4211 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4220 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4221 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4222 assert(res &&
"builder didn't return the right type");
4232 result.addAttribute(
"cast", cast);
4234 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4243 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4244 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4245 assert(res &&
"builder didn't return the right type");
4250 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4252 op->
getAttr(
"indexing_maps"));
4256BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4263 return {mapLHS, mapRHS, mapOut};
4267 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4270 if (maps.size() != 3)
4273 if (failed(positions))
4284 BatchMatmulOp::getRegionBuilder(),
4285 getDefaultIndexingMaps(builder));
4293 build(builder, state, inputs, outputs, attributes);
4294 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4295 assert(res &&
"builder didn't return the right type");
4304 BatchMatmulOp::getRegionBuilder(),
4305 getDefaultIndexingMaps(builder));
4314 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4315 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4316 assert(res &&
"builder didn't return the right type");
4324 result.addAttribute(
"cast", cast);
4326 BatchMatmulOp::getRegionBuilder(),
4327 getDefaultIndexingMaps(builder));
4336 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4337 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4338 assert(res &&
"builder didn't return the right type");
4343 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4345 op->
getAttr(
"indexing_maps"));
4349BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4356 return {mapLHS, mapRHS, mapOut};
4360 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4363 if (maps.size() != 3)
4366 if (failed(positions))
4377 BatchMatmulOp::getRegionBuilder(),
4378 getDefaultIndexingMaps(builder));
4386 build(builder, state, inputs, outputs, attributes);
4387 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4388 assert(res &&
"builder didn't return the right type");
4397 BatchMatmulOp::getRegionBuilder(),
4398 getDefaultIndexingMaps(builder));
4407 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4408 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4409 assert(res &&
"builder didn't return the right type");
4417 result.addAttribute(
"cast", cast);
4419 BatchMatmulOp::getRegionBuilder(),
4420 getDefaultIndexingMaps(builder));
4429 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4430 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4431 assert(res &&
"builder didn't return the right type");
4436 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4438 op->
getAttr(
"indexing_maps"));
4446 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4457 auto dimExpr = dyn_cast<AffineDimExpr>(
result);
4458 assert(dimExpr &&
"affine_map is a projected permutation");
4459 dimsInOutput[dimExpr.getPosition()] =
true;
4463 for (
auto dimOccursInOutput : dimsInOutput)
4464 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4465 : utils::IteratorType::reduction);
4467 return iteratorTypes;
4470unsigned ContractOp::getNumRegionArgs() {
return 3; }
// Populates the body block of a contraction: casts block arguments 0 and 1 to
// `outType`, multiplies them and yields `result`. Expects exactly three block
// arguments.
// NOTE(review): the argument-count check, the definition of `outType`, and
// the computation of `result` from the product are on lines not visible here.
4473void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
4474 ArrayRef<NamedAttribute> attrs,
4477 emitError() <<
"ContractOp regionBuilder expects 3 args, got "
4482 "ContractOp regionBuilder expects 3 args");
4483 RegionBuilderHelper helper(
b, block);
// Signed cast is the default; an attribute named "cast" holding a TypeFnAttr
// is looked up (the override assignment itself is not visible here).
4485 TypeFn castSignedness = TypeFn::cast_signed;
4486 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4487 return attr.
getName() ==
"cast";
4489 if (castIter != attrs.end()) {
4490 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
// Bring both inputs to the output element type before multiplying.
4496 Value lhsAtOutType =
4497 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
4498 Value rhsAtOutType =
4499 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
4500 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
// A null product indicates that building the multiplication failed.
4502 if (!productAtOutType)
4508 helper.yieldOutputs({
result});
// Custom parser. Unlike the matmul ops in this file, the indexing maps are
// mandatory: a failed or null `indexingMapsAttr` is reported as
// "expected 'indexing_maps' attribute". The parsed maps are then stored under
// "indexing_maps".
// NOTE(review): the call producing `indexingMapsAttr` and the rest of the
// parse are on lines not visible in this excerpt.
4511ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
4513 if (
failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
4515 "expected 'indexing_maps' attribute");
4516 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
// Custom printer: always emits " indexing_maps = [...]", then delegates the
// rest of the op to a helper, passing "indexing_maps" and
// "operandSegmentSizes" as the attributes to elide.
// NOTE(review): the name of the helper being called is on a line not visible
// in this excerpt.
4522void ContractOp::print(OpAsmPrinter &p) {
4523 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4525 p, getOperation(), getInputs(), getOutputs(),
4526 {
"indexing_maps",
"operandSegmentSizes"});
// Verifier for linalg.contract. It checks each indexing map against its
// operand, counts how often every iteration-space dimension is used by the
// inputs and by the output, and requires at least one contracting dimension
// (used by both inputs and not by the output).
// NOTE(review): several guarding conditions and early returns are on lines
// not visible in this excerpt; comments below cover only what is shown.
4529LogicalResult ContractOp::verify() {
// -1 means "not yet known"; set from the first map that is checked.
4530 int iterationSpaceDims = -1;
// Per-dimension usage counts, for the input maps and the output map.
4535 SmallVector<size_t> inOccurrences;
4536 SmallVector<size_t> outOccurrences;
// Checks one (map, operand type) pair and updates the occurrence counters.
4539 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4540 bool isInput) -> LogicalResult {
4543 return emitError(
"provided affine_map is not a projected permutation");
4546 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
4548 return emitError(
"ranks of shaped operand and results of corresponding "
4549 "affine_map differ");
4551 return emitError(
"affine_map specifies shaped access while operand has "
// The first map fixes the iteration-space size; later maps must agree.
4556 if (iterationSpaceDims == -1) {
4558 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4559 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4560 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
4561 return emitError(
"iteration spaces of provided affine_maps differ");
// Count each dimension referenced by this map's results.
4565 for (AffineExpr affineExpr : affineMap.
getResults()) {
4566 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4568 llvm_unreachable(
"affine_map is a projected permutation");
4571 inOccurrences[affineDimExpr.getPosition()] += 1;
4573 outOccurrences[affineDimExpr.getPosition()] += 1;
// Operands are zipped with {true, true, false}: two inputs, then one output.
4579 for (
auto &&[affineMap, operandType, isInput] :
4580 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4581 SmallVector<bool>{
true,
true,
false})) {
4582 if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
// Classify every iteration-space dimension by its usage counts.
4586 bool hasContractingDim =
false;
4587 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4588 size_t inOccCount = inOccurrences[dimIndex];
4589 size_t outOccCount = outOccurrences[dimIndex];
// Contracting dim: appears in both inputs and not in the output.
4592 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4594 if (inOccCount == 0 && outOccCount == 0)
4595 return emitError() <<
"iteration space dim at index " << dimIndex
4596 <<
" not used to access any operand";
// A dim used by exactly one input must also appear exactly once in the output.
4607 if (inOccCount == 1 && outOccCount != 1)
4609 <<
"iteration space dim at index " << dimIndex
4610 <<
" is neither a contracting dim nor of parallel iteration type";
4613 if (!hasContractingDim)
4614 return emitError(
"'indexing_maps' do not specify a contracting dimension");
4619LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4623void ContractOp::getEffects(
4624 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4626 if (hasPureTensorSemantics())
// Returns the three default indexing maps of a batch matmul over a 4-D
// iteration space, in the order they are pushed: (d0, d1, d3), (d0, d3, d2)
// and (d0, d1, d2).
// NOTE(review): the binding of d0..d3 is on a line not visible here;
// presumably dims 0..3 in order -- confirm.
4638SmallVector<AffineMap>
4639BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4640 AffineExpr d0, d1, d2, d3;
4641 SmallVector<AffineMap> indexingMaps;
4643 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
4644 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
4645 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
4646 return indexingMaps;
// Tests whether `attr` describes the default batch-matmul indexing maps: an
// array of exactly three maps whose result positions are (0, 1, 3),
// (0, 3, 2) and (0, 1, 2).
// NOTE(review): the null check on `maps` and the computation of `positions`
// are on lines not visible in this excerpt.
4649bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4650 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4653 if (maps.size() != 3)
// Expected dim positions per operand, in the order LHS, RHS, output.
4658 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4659 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4660 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
// Returns the four loop iterator types of a batch matmul: three parallel
// loops followed by one reduction loop.
4663SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4664 return SmallVector<utils::IteratorType>{
4665 utils::IteratorType::parallel, utils::IteratorType::parallel,
4666 utils::IteratorType::parallel, utils::IteratorType::reduction};
4669unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
4671std::string BatchMatmulOp::getLibraryCallName() {
// Returns true when the op's indexing maps differ from `defaultMaps`.
// NOTE(review): the initializer of `defaultMaps` is on a line not visible
// here; presumably getDefaultIndexingMaps(...) -- confirm.
4677bool BatchMatmulOp::hasUserDefinedMaps() {
4678 SmallVector<AffineMap, 3> defaultMaps =
4680 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4681 return defaultMaps != explicitMaps;
// Decides whether `bcastMap` is an acceptable broadcast map for the LHS
// (`isLHS` true) or RHS operand of a batch matmul.
// NOTE(review): most of the decision logic (the assert condition, the
// single-result case, the LHS arm of the two-result case and the return) is
// on lines not visible in this excerpt.
4691bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
bool isLHS) {
4693 "Expected less than 3 result dim expr.");
4694 bool isValid =
false;
// Iteration-space positions: batch = 0, m = 1, n = 2, k = 3.
4695 enum Indices { batchPos, mPos, nPos, kPos };
4697 AffineExpr expr = bcastMap.
getResult(0);
4700 AffineExpr expr0 = bcastMap.
getResult(0);
4701 AffineExpr expr1 = bcastMap.
getResult(1);
// Visible arm of the two-result case: accepts (batch, k) or (k, n).
4706 : ((expr0.isFunctionOfDim(batchPos) &&
4707 expr1.isFunctionOfDim(kPos)) ||
4708 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
// Populates the body block of a batch matmul: casts block arguments 0 and 1
// to `toType`, multiplies them, adds the product to block argument 2 and
// yields the sum. Expects exactly three block arguments.
// NOTE(review): the argument-count check, the definition of `toType`, the
// declaration of `mulVal` and the early returns are on lines not visible here.
4713void BatchMatmulOp::regionBuilder(
4714 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4717 emitError() <<
"BatchMatmulOp regionBuilder expects 3 args, got "
4722 "BatchMatmulOp regionBuilder expects 3 args");
4723 RegionBuilderHelper helper(
b, block);
4724 SmallVector<Value> yields;
// Signed cast is the default; an attribute named "cast" holding a TypeFnAttr
// is looked up (the override assignment itself is not visible here).
4726 TypeFn castVal = TypeFn::cast_signed;
4727 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4728 return attr.
getName() ==
"cast";
4730 if (castIter != attrs.end()) {
4731 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
// Cast both inputs to the target type, then multiply.
4736 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
4737 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
4739 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
// A null value here indicates that building one of the operations failed.
4740 if (!castValA || !castValB || !mulVal)
// Accumulate into the output block argument.
4742 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
4746 yields.push_back(addVal);
4747 helper.yieldOutputs(yields);
// Custom parser: collects the affine-map attributes found in the input, falls
// back to the default batch-matmul maps when none were given, stores them
// under "indexing_maps" and hands the rest to parseNamedStructuredOp.
// NOTE(review): the code that reads the optional indexing_maps list from the
// parser is on lines not visible in this excerpt.
4750ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
4751 SmallVector<Attribute, 3> indexingMapsAttr;
// Every parsed element must be an AffineMapAttr.
4763 if (!isa<AffineMapAttr>(mapAttr)) {
4765 "expected affine map attribute");
4767 indexingMapsAttr.push_back(mapAttr);
// No maps given: use the op's default maps.
4777 if (indexingMapsAttr.empty()) {
4778 indexingMapsAttr = llvm::map_to_vector(
4779 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4780 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4782 result.addAttribute(
"indexing_maps",
4785 return ::parseNamedStructuredOp(parser,
result,
4786 BatchMatmulOp::getNumRegionArgs(),
4787 BatchMatmulOp::getRegionBuilder());
// Custom printer: emits " indexing_maps = [...]" only when the op's maps
// differ from the default maps, and lists the attributes to be elided.
// NOTE(review): the call that consumes `elidedAttrs` is on lines not visible
// in this excerpt.
4790void BatchMatmulOp::print(OpAsmPrinter &p) {
// Default maps wrapped as attributes, for comparison with the op's own maps.
4791 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4792 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4793 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4794 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4795 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
// "indexing_maps" is elided here because it was handled explicitly above.
4797 std::array<StringRef, 3> elidedAttrs = {
4798 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4804LogicalResult BatchMatmulOp::verify() {
4807 if (!hasUserDefinedMaps())
4810 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4817LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4818 SmallVectorImpl<OpFoldResult> &) {
4822void BatchMatmulOp::getEffects(
4823 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4825 if (hasPureTensorSemantics())
4839struct ArityGroupAndKind {
4841 ElementwiseArityGroup arityGroup;
4847 TernaryFn ternaryFn;
// Converts an ElementwiseArityGroup enumerator to its underlying value as an
// unsigned integer. Callers in this file add 1 to it to obtain a region
// argument count.
4851unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4852 return static_cast<unsigned>(arityGroup);
4857 constexpr int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4858 constexpr int lastBinary =
4859 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4860 constexpr int lastTernary =
4861 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4863 int val =
static_cast<int>(kind);
4864 ArityGroupAndKind
result;
4866 if (val < lastUnary) {
4867 result.arityGroup = ElementwiseArityGroup::Unary;
4868 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4871 if (val < lastBinary) {
4872 result.arityGroup = ElementwiseArityGroup::Binary;
4873 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4876 if (val >= lastTernary) {
4877 llvm_unreachable(
"unhandled ElementwiseFn");
4879 result.arityGroup = ElementwiseArityGroup::Ternary;
4880 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4885 auto rank = getResultRank();
4890ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
// Custom parser for linalg.elementwise: reads the mandatory 'kind' attribute,
// collects optional affine-map attributes, parses the rest of the op, and
// falls back to default indexing maps when none were supplied.
// NOTE(review): the parser calls that read the attributes, the definition of
// `arityGroupAndKind`, and most error-return statements are on lines not
// visible in this excerpt.
4896ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
4899 mlir::linalg::ElementwiseKind elemwiseKindVal;
// The 'kind' attribute must be an ElementwiseKindAttr.
4904 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4905 if (!elemwiseKindAttr)
4907 "expected ElementwiseKind attribute");
4908 elemwiseKindVal = elemwiseKindAttr.getValue();
4911 "expected operation 'kind' attribute");
4914 "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
4917 SmallVector<Attribute, 3> indexingMapsAttr;
// Every parsed map element must be an AffineMapAttr.
4927 if (!isa<AffineMapAttr>(mapAttr))
4929 "expected affine map attribute");
4930 indexingMapsAttr.push_back(mapAttr);
// The arity group value plus one (presumably inputs plus the output -- confirm).
4941 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4943 ElementwiseOp::getRegionBuilder())) {
4945 "unable to parse elemwise op");
// No maps given: derive defaults from the rank of the last operand's type.
4949 if (indexingMapsAttr.empty()) {
4952 auto resultType =
result.operands[
result.operands.size() - 1].getType();
4953 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4956 "return type needs to be shaped type");
4957 auto numDims = shapedType.getRank();
4958 indexingMapsAttr = llvm::map_to_vector(
4959 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4961 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4964 result.addAttribute(
"indexing_maps",
// Custom printer: emits " indexing_maps = [...]" only when the op's maps
// differ from the defaults for its arity and result rank.
// NOTE(review): the printing of 'kind', the definition of `arity`, the rest
// of `elidedAttrs` and the final print call are on lines not visible here.
4969void ElementwiseOp::print(OpAsmPrinter &p) {
4972 SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes",
"kind",
4976 unsigned numDims = getResultRank();
// Default maps for (arity + 1) operands over `numDims` dimensions.
4978 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4979 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4981 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4983 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4984 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
// Populates the body block of an elementwise op: reads the "kind" attribute,
// checks the number of block arguments against the arity group, builds the
// unary, binary or ternary computation, and yields `result`.
// NOTE(review): the definition of `groupAndKind`, the left-hand side of the
// argument-count comparisons, the bodies of the three arity branches and the
// declaration of `result` are on lines not visible in this excerpt.
4992void ElementwiseOp::regionBuilder(
4993 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4995 ElementwiseKind elemwiseKind;
// Find the attribute named "kind"; it is expected to be an ElementwiseKindAttr.
4996 for (
auto attr : attrs) {
4997 if (attr.getName() ==
b.getStringAttr(
"kind")) {
4998 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4999 assert(kindAttr &&
"op kind attribute incorrectly set");
5000 elemwiseKind = kindAttr.getValue();
5006 auto arityGroup = groupAndKind.arityGroup;
5007 auto kind = groupAndKind.kind;
// Expected argument count is the arity group value plus one.
5009 getArityGroupAsUInt(arityGroup) + 1 ) {
5010 emitError() <<
"Elementwise regionBuilder expects "
5011 << (getArityGroupAsUInt(arityGroup) + 1) <<
" args, got "
5016 getArityGroupAsUInt(arityGroup) + 1
5017 &&
"Elementwise regionBuilder number of block args mismatch");
5019 RegionBuilderHelper helper(
b, block);
5020 SmallVector<Value> yields;
// Dispatch on the arity group.
5023 if (arityGroup == ElementwiseArityGroup::Unary) {
5026 }
else if (arityGroup == ElementwiseArityGroup::Binary) {
5030 }
else if (arityGroup == ElementwiseArityGroup::Ternary) {
5035 assert(
false &&
"found unhandled category in elemwise");
5038 yields.push_back(
result);
5039 helper.yieldOutputs(yields);
5042LogicalResult ElementwiseOp::fold(FoldAdaptor,
5043 SmallVectorImpl<OpFoldResult> &) {
5047void ElementwiseOp::getEffects(
5048 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5050 if (hasPureTensorSemantics())
5063template <
typename OpTy,
typename>
5066 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5067 ? packOrUnPack.getDestType()
5068 : packOrUnPack.getSourceType();
5069 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5070 ? packOrUnPack.getSourceType()
5071 : packOrUnPack.getDestType();
5073 packedType.getShape().take_front(unpackedType.getRank()));
5074 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5095 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5097 .take_back(mixedTiles.size()),
5099 int64_t dimSize = std::get<0>(it);
5100 if (dimSize == ShapedType::kDynamic) {
5101 newMixedTileSizes.push_back(std::get<1>(it));
5104 newMixedTileSizes.push_back(rewriter.
getIndexAttr(dimSize));
5107 return newMixedTileSizes;
5110template <
typename OpTy>
5114 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5115 "applies to only pack or unpack operations");
5116 int64_t destRank = op.getDestRank();
5118 for (
auto dim : llvm::seq<int64_t>(0, destRank))
5119 reifiedReturnShapes[0][dim] =
5124template <
typename OpTy>
5126 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5127 "applies to only pack or unpack operations");
5131 assert(tiles.size() == dimsToTile.size() &&
5132 "tiles must match indices of dimension to block");
5134 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5135 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5136 return dimAndTileMapping;
5139template <
typename OpTy>
5141 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5142 "applies to only pack or unpack operations");
5145 unsigned dynamicValIndex = 0;
5146 for (
int64_t staticTile : op.getStaticInnerTiles()) {
5147 if (ShapedType::isStatic(staticTile))
5150 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5152 return mixedInnerTiles;
5155template <
typename OpTy>
5157 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5158 "applies to only pack or unpack operations");
5171 size_t dimsPosSize = dimsPos.size();
5172 if (dimsPosSize > rank)
5175 if (dimsPosSize != uniqued.size())
5177 return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
5178 return dimPos < 0 || dimPos >=
static_cast<int64_t>(rank);
5182template <
typename OpTy>
5184 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5185 "applies to only pack or unpack operations");
5186 Operation *op = packOrUnPack.getOperation();
5196 if (!packOrUnPack.getSourceType().hasRank() ||
5197 !packOrUnPack.getDestType().hasRank())
5198 return op->
emitError(
"expected both source and destination to have rank");
5201 if (!packOrUnPack.hasPureBufferSemantics() &&
5202 !packOrUnPack.hasPureTensorSemantics())
5203 return op->
emitError(
"mixing tensor and buffer semantics is not allowed");
5204 const unsigned numResults = packOrUnPack.getNumResults();
5205 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5206 return op->
emitError(
"expected 1 result, got ") << numResults;
5207 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5208 return op->
emitError(
"expected 0 results, got ") << numResults;
5212 if (hasZeros(mixedTiles))
5213 return op->
emitError(
"invalid zero tile factor");
5216 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5217 ? packOrUnPack.getSourceType()
5218 : packOrUnPack.getDestType();
5219 size_t unpackedRank = unpackedType.getRank();
5223 return op->
emitError(
"invalid inner_dims_pos vector");
5225 return op->
emitError(
"invalid outer_dims_perm vector");
5226 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5227 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
5231 if (mixedTiles.size() > unpackedRank) {
5232 return op->
emitError(
"tiling factors must be less than or equal to the "
5233 "input rank for pack or output rank for unpack");
5235 if (mixedTiles.size() != innerDimsPos.size()) {
5237 "tiling factors must equal the number of dimensions to tile");
5240 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5241 ? packOrUnPack.getDestType()
5242 : packOrUnPack.getSourceType();
5243 size_t packedRank = packedType.getRank();
5245 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5246 if (expectedPackedRank != packedRank) {
5248 "packed rank != (unpacked rank + num tiling factors), got ")
5249 << packedRank <<
" != " << expectedPackedRank;
5256 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5257 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5258 for (
auto it : llvm::enumerate(llvm::zip(
5259 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5260 int64_t dimSize = std::get<0>(it.value());
5262 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
5263 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5264 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5265 if (dimSize != staticTileSize)
5267 "mismatch in inner tile sizes specified and shaped of "
5268 "tiled dimension in the packed type at index ")
5269 << it.index() <<
": got " << dimSize <<
" != " << staticTileSize;
5270 }
else if (!ShapedType::isDynamic(dimSize)) {
5271 return op->
emitError(
"mismatch in inner tile sizes specified at index ")
5272 << it.index() <<
": got static shape " << dimSize
5273 <<
" but dynamic tile size";
5278 auto elementType = unpackedType.getElementType();
5279 Type expectedType, actualType;
5280 if (packOrUnPack.hasPureTensorSemantics()) {
5281 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5282 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5284 expectedType = MemRefType::get(expectedPackedShape, elementType);
5285 actualType = MemRefType::get(packedType.getShape(), elementType);
5288 << expectedType <<
" for the packed domain value, got "
5301struct PackOrUnPackTransposeResult {
5308template <
typename OpTy>
5309static PackOrUnPackTransposeResult
5313 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5314 "applies to only pack or unpack operations");
5315 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5316 "some permutation must be non-empty");
5317 PackOrUnPackTransposeResult metadata;
5318 metadata.innerDimsPos =
5320 metadata.innerTiles =
5322 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5323 ? packOrUnPackOp.getSourceRank()
5324 : packOrUnPackOp.getDestRank();
5325 metadata.outerDimsPerm =
5326 packOrUnPackOp.getOuterDimsPerm().empty()
5327 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5329 if (!innerPermutation.empty()) {
5330 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5332 "invalid inner permutation");
5336 if (!outerPermutation.empty()) {
5337 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5339 "invalid outer permutation");
5350 if (!getResults().empty())
5351 setNameFn(getResult(),
"pack");
5361 Type sourceType, destType, resultType;
5378 SmallVector<int64_t> outerDimsPermVec;
5381 if (parser.parseInteger(value))
5383 outerDimsPermVec.push_back(value);
5393 SmallVector<int64_t> innerDimsPosVec;
5396 if (parser.parseInteger(value))
5398 innerDimsPosVec.push_back(value);
5410 for (
auto val : staticTilesAttr.
asArrayRef())
5411 staticTiles.push_back(val);
5428 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5431 "pack/unpack requires '->' and destination type");
5435 resultType = destType;
5441 if (!paddingValue.empty() &&
5446 if (!dynamicTiles.empty() &&
5451 result.addAttribute(
"static_inner_tiles",
5453 result.addAttribute(
"inner_dims_pos", innerDimsPos);
5455 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
5457 SmallVector<int32_t> segmentSizes = {
5458 1, 1,
static_cast<int32_t
>(paddingValue.size()),
5459 static_cast<int32_t
>(dynamicTiles.size())};
5460 result.addAttribute(
"operandSegmentSizes",
5464 result.addTypes(resultType);
5469void PackOp::print(OpAsmPrinter &p) {
5470 p <<
" " << getSource();
5472 if (getPaddingValue()) {
5473 p <<
" padding_value(" << getPaddingValue() <<
" : "
5474 << getPaddingValue().getType() <<
")";
5477 if (!getOuterDimsPerm().empty()) {
5478 p <<
" outer_dims_perm = [";
5479 llvm::interleaveComma(getOuterDimsPerm(), p);
5483 p <<
" inner_dims_pos = [";
5484 llvm::interleaveComma(getInnerDimsPos(), p);
5487 p <<
" inner_tiles = ";
5490 p <<
" into " << getDest();
5493 {
"static_inner_tiles",
"inner_dims_pos",
5494 "outer_dims_perm",
"operandSegmentSizes"});
5496 p <<
" : " << getSource().getType();
5497 p <<
" -> " << getDest().getType();
5500void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5501 Value dest, ArrayRef<int64_t> innerDimsPos,
5502 ArrayRef<OpFoldResult> innerTiles,
5503 std::optional<Value> paddingValue,
5504 ArrayRef<int64_t> outerDimsPerm) {
5505 assert(innerDimsPos.size() == innerTiles.size() &&
5506 "number of tile sizes specified must match the specified number of "
5507 "original dimensions to be tiled");
5508 SmallVector<int64_t> staticTileSizes;
5509 SmallVector<Value> dynamicTileSizes;
5511 build(builder, state, dest.
getType(), source, dest,
5512 paddingValue ? *paddingValue :
nullptr,
5513 outerDimsPerm.empty() ?
nullptr
5520PackOp::reifyResultShapes(OpBuilder &builder,
5529SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5533SmallVector<int64_t> PackOp::getStaticTiles() {
5537ArrayRef<int64_t> PackOp::getAllOuterDims() {
5538 ShapedType inputType = getSourceType();
5539 int64_t inputRank = inputType.getRank();
5540 return getDestType().getShape().take_front(inputRank);
5543SmallVector<int64_t> PackOp::getTiledOuterDims() {
5544 auto innerDimsPos = getInnerDimsPos();
5545 SmallVector<int64_t> outerDims(getAllOuterDims());
5546 SmallVector<int64_t> res;
5549 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5551 if (!outerDimPermInv.empty())
5555 for (
auto index : innerDimsPos)
5556 res.push_back(outerDims[index]);
5561bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5562 ArrayRef<int64_t> innerDimsPos,
5563 ArrayRef<int64_t> outputShape,
5564 ArrayRef<int64_t> outerDimsPerm,
5565 ArrayRef<OpFoldResult> innerTiles) {
5566 SmallVector<int64_t> outputTileSizes(
5567 outputShape.take_front(inputShape.size()));
5568 if (!outerDimsPerm.empty()) {
5569 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5570 "expected output and outer_dims_perm to have same size");
5574 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5575 if (ShapedType::isDynamic(inputShape[pos]))
5578 if (!constantTile) {
5579 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5580 (inputShape[pos] % outputTileSizes[pos] != 0))
5583 assert(*constantTile != 0 &&
"static tile size can't be zero");
5584 if (inputShape[pos] % (*constantTile) != 0) {
5592bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5593 ArrayRef<int64_t> innerDimsPos,
5594 ArrayRef<int64_t> outputShape,
5595 ArrayRef<int64_t> outerDimsPerm,
5596 ArrayRef<OpFoldResult> innerTiles) {
5597 SmallVector<int64_t> outputTileSizes(
5598 outputShape.take_front(inputShape.size()));
5599 if (!outerDimsPerm.empty()) {
5600 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5601 "expected output and outer_dims_perm to have same size");
5605 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5606 if (ShapedType::isDynamic(inputShape[pos]) ||
5607 ShapedType::isDynamic(outputTileSizes[pos]))
5612 assert(*constantTile != 0 &&
"static tile size can't be zero");
5613 if (inputShape[pos] % (*constantTile) != 0)
5619LogicalResult PackOp::verify() {
5626 auto paddingValue = getPaddingValue();
5630 << getSourceType().getElementType()
5631 <<
" but got: " << paddingValue.getType();
5634 if (!paddingValue &&
5635 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
5636 getDestType().
getShape(), getOuterDimsPerm(),
5639 "invalid tile factor or output size provided. Only full tiles are "
5640 "supported when padding_value is not set");
5647static SmallVector<int64_t>
5650 for (
auto o : ofrs) {
5652 if (llvm::dyn_cast_if_present<Value>(o))
5653 result.push_back(ShapedType::kDynamic);
5665 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5666 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5668 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5669 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5672 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5673 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5677 if (!outerDimsPerm.empty())
5681 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5685SmallVector<OpFoldResult> PackOp::getResultShape(
5686 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5687 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5688 ArrayRef<int64_t> outerDimsPerm) {
5689 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5693 AffineExpr ceilDivExpr = s0.
ceilDiv(s1);
5694 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5696 builder, loc, ceilDivExpr,
5697 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5699 if (!outerDimsPerm.empty())
5701 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5703 SmallVector<int64_t> resultTypeShape =
5706 innerDimsPos, outerDimsPerm);
5712 for (
unsigned i = 0; i < resultDims.size(); ++i) {
5713 if (ShapedType::isStatic(resultTypeShape[i]))
5722RankedTensorType PackOp::inferPackedTensorType(
5723 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5724 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5725 SmallVector<int64_t> resultShape = inferPackedShape(
5726 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5727 return RankedTensorType::get(resultShape, sourceType.getElementType());
5730MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5731 ArrayRef<int64_t> innerTileSizes,
5732 ArrayRef<int64_t> innerDimsPos,
5733 ArrayRef<int64_t> outerDimsPerm) {
5734 SmallVector<int64_t> resultShape = inferPackedShape(
5735 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5736 return MemRefType::get(resultShape, sourceType.getElementType());
5739Value PackOp::createDestinationTensor(OpBuilder &
b, Location loc, Value source,
5740 ArrayRef<OpFoldResult> innerTileSizes,
5741 ArrayRef<int64_t> innerDimsPos,
5742 ArrayRef<int64_t> outerDimsPerm) {
5743 AffineExpr dim0, dim1;
5745 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5750 SmallVector<OpFoldResult> mixedSizes;
5751 for (
auto [index, value] : llvm::enumerate(
5752 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
5753 if (ShapedType::isDynamic(value))
5754 mixedSizes.push_back(
5755 tensor::DimOp::create(
b, loc, source, index).getResult());
5757 mixedSizes.push_back(
b.getIndexAttr(value));
5759 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5760 int64_t dimPos = std::get<0>(it);
5761 OpFoldResult tileSize = std::get<1>(it);
5762 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5764 if (!outerDimsPerm.empty())
5767 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5768 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5769 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
5772PackOp PackOp::createTransposedClone(OpBuilder &
b, Location loc,
5773 ArrayRef<int64_t> innerPermutation,
5774 ArrayRef<int64_t> outerPermutation) {
5776 *
this, innerPermutation, outerPermutation);
5777 Value transposedDest =
5778 createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
5779 metadata.innerDimsPos, metadata.outerDimsPerm);
5780 return PackOp::create(
b, loc, getSource(), transposedDest,
5781 metadata.innerDimsPos, metadata.innerTiles,
5782 getPaddingValue(), metadata.outerDimsPerm);
5785template <
typename OpTy>
5790 if (op.hasPureTensorSemantics())
5793 for (
OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5794 if (!llvm::isa<MemRefType>(opOperand.
get().
getType()))
5797 if (&opOperand == &op.getSourceMutable()) {
5801 }
else if (&opOperand == &op.getDestMutable()) {
5812void PackOp::getEffects(
5818void UnPackOp::getEffects(
5825template <
typename OpTy>
5827 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5828 "applies to only pack or unpack operations");
5829 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5831 : op.getSourceType();
5833 for (
auto [dimDest,
tile] : llvm::zip(
5834 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5836 if (!constTileSize || ShapedType::isDynamic(dimDest))
5843 if (!hasPureTensorSemantics())
5845 if (getPaddingValue())
5860 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5862 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5874 auto packTiles = packOp.getMixedTiles();
5875 auto unPackTiles = unPackOp.getMixedTiles();
5876 if (packTiles.size() != unPackTiles.size())
5878 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
5887 auto srcType = op.getSourceType();
5888 if (llvm::any_of(op.getInnerDimsPos(),
5889 [&](
int64_t pos) { return srcType.isDynamicDim(pos); }))
5891 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5893 return !PackOp::requirePaddingValue(
5894 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5895 op.getOuterDimsPerm(), op.getMixedTiles());
5902 bool changeNeeded =
false;
5903 srcShape.assign(packOp.getSourceType().getShape().begin(),
5904 packOp.getSourceType().getShape().end());
5905 destShape.assign(packOp.getDestType().getShape().begin(),
5906 packOp.getDestType().getShape().end());
5907 llvm::SmallSetVector<int64_t, 4> innerDims;
5908 innerDims.insert_range(packOp.getInnerDimsPos());
5910 if (!packOp.getOuterDimsPerm().empty())
5912 int srcRank = packOp.getSourceRank();
5913 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
5914 if (innerDims.contains(i))
5918 if (!inverseOuterDimsPerm.empty())
5919 destPos = inverseOuterDimsPerm[srcPos];
5920 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5921 ShapedType::isDynamic(destShape[destPos])) {
5924 int64_t size = srcShape[srcPos];
5925 if (ShapedType::isDynamic(size))
5926 size = destShape[destPos];
5927 srcShape[srcPos] = size;
5928 destShape[destPos] = size;
5929 changeNeeded =
true;
5931 return changeNeeded;
5934LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
5936 if (!packOp.hasPureTensorSemantics())
5940 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5941 if (unPackOp.getSourceType() == packOp.getDestType() &&
5942 !packOp.getPaddingValue() &&
5945 rewriter.
replaceOp(packOp, unPackOp.getSource());
5953 packOp.getPaddingValueMutable().clear();
5959 SmallVector<int64_t> srcShape, destShape;
5961 Location loc = packOp.getLoc();
5962 Value source = packOp.getSource();
5963 if (srcShape != packOp.getSourceType().getShape()) {
5964 auto newSrcType = packOp.getSourceType().clone(srcShape);
5966 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5968 Value dest = packOp.getDest();
5969 ShapedType originalResultType = packOp.getDestType();
5970 bool needUpdateDestType = (destShape != originalResultType.getShape());
5971 if (needUpdateDestType) {
5972 auto newDestType = packOp.getDestType().clone(destShape);
5974 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5977 packOp.getSourceMutable().assign(source);
5978 packOp.getDestMutable().assign(dest);
5979 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
5982 if (needUpdateDestType) {
5984 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5985 packOp.getResult());
5994template <
typename PackOrUnpackOp>
5996 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5997 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5998 "Function meant for pack/unpack");
6003 int64_t numPackedDims = innerDimsPos.size();
6004 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
6005 if (orderedDims != innerDimsPos) {
6011 int64_t packedRank = packedTensorType.getRank();
6021 return llvm::all_of(
6022 llvm::seq<int64_t>(0, packedRank - numPackedDims),
6023 [&packedShape](
int64_t i) {
return packedShape[i] == 1; });
6026bool PackOp::isLikePad() {
6027 auto packedTensorType =
6028 llvm::cast<ShapedType>((*this)->getResultTypes().front());
6032::mlir::LogicalResult
6033PackOp::fold(FoldAdaptor adaptor,
6035 if (!hasPureTensorSemantics())
6037 std::optional<Attribute> paddingValue;
6038 if (
auto pad = adaptor.getPaddingValue())
6040 if (
OpFoldResult reshapedSource = reshapeConstantSource(
6041 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6042 cast<TensorType>(getDestType()), paddingValue)) {
6043 results.push_back(reshapedSource);
6069 if (!op.hasPureTensorSemantics())
6090 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6091 op.getInnerDimsPos(), newMixedTileSizes,
6092 op.getPaddingValue(), op.getOuterDimsPerm());
6093 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6096 Value oldResult = op.getResult();
6097 Value newResult = newOp.getResult();
6100 ? tensor::CastOp::create(rewriter, op->getLoc(),
6101 oldResult.
getType(), newResult)
6114void UnPackOp::getAsmResultNames(
6116 if (!getResults().empty())
6117 setNameFn(getResult(),
"unpack");
6126 Type sourceType, destType, resultType;
6138 if (parser.parseInteger(value))
6140 outerDimsPermVec.push_back(value);
6150 SmallVector<int64_t> innerDimsPosVec;
6153 if (parser.parseInteger(value))
6155 innerDimsPosVec.push_back(value);
6167 for (
auto val : staticTilesAttr.
asArrayRef())
6168 staticTiles.push_back(val);
6185 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6188 "pack/unpack requires '->' and destination type");
6192 resultType = destType;
6198 if (!dynamicTiles.empty() &&
6203 result.addAttribute(
"static_inner_tiles",
6205 result.addAttribute(
"inner_dims_pos", innerDimsPos);
6207 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
6209 SmallVector<int32_t> segmentSizes = {
6210 1, 1, 0,
static_cast<int32_t
>(dynamicTiles.size())};
6211 result.addAttribute(
"operandSegmentSizes",
6215 result.addTypes(resultType);
6220void UnPackOp::print(OpAsmPrinter &p) {
6221 p <<
" " << getSource();
6223 if (!getOuterDimsPerm().empty()) {
6224 p <<
" outer_dims_perm = [";
6225 llvm::interleaveComma(getOuterDimsPerm(), p);
6229 p <<
" inner_dims_pos = [";
6230 llvm::interleaveComma(getInnerDimsPos(), p);
6233 p <<
" inner_tiles = ";
6236 p <<
" into " << getDest();
6239 {
"static_inner_tiles",
"inner_dims_pos",
6240 "outer_dims_perm",
"operandSegmentSizes"});
6242 p <<
" : " << getSource().getType();
6243 p <<
" -> " << getDest().getType();
6247UnPackOp::reifyResultShapes(OpBuilder &builder,
6256SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6260SmallVector<int64_t> UnPackOp::getStaticTiles() {
6264ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6265 ShapedType destType = getDestType();
6266 int64_t destRank = destType.getRank();
6267 return getSourceType().getShape().take_front(destRank);
6270SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6271 auto innerDimsPos = getInnerDimsPos();
6272 SmallVector<int64_t> outerDims(getAllOuterDims());
6273 SmallVector<int64_t> res;
6276 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6278 if (!outerDimPermInv.empty())
6282 for (
auto index : innerDimsPos)
6283 res.push_back(outerDims[index]);
6288LogicalResult UnPackOp::verify() {
6293 if (!hasPureTensorSemantics())
6302void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6303 Value dest, ArrayRef<int64_t> innerDimsPos,
6304 ArrayRef<OpFoldResult> innerTiles,
6305 ArrayRef<int64_t> outerDimsPerm) {
6306 assert(innerDimsPos.size() == innerTiles.size() &&
6307 "number of tile sizes specified must match the specified number of "
6308 "original dimensions to be tiled");
6309 SmallVector<int64_t> staticTileSizes;
6310 SmallVector<Value> dynamicTileSizes;
6312 build(builder, state, dest.
getType(), source, dest,
6313 outerDimsPerm.empty() ?
nullptr
6319Value UnPackOp::createDestinationTensor(OpBuilder &
b, Location loc,
6321 ArrayRef<OpFoldResult> innerTileSizes,
6322 ArrayRef<int64_t> innerDimsPos,
6323 ArrayRef<int64_t> outerDimsPerm) {
6324 AffineExpr sym0, sym1;
6326 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6330 SmallVector<OpFoldResult> mixedSizes;
6331 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
6333 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6334 if (srcType.isDynamicDim(i))
6335 mixedSizes.push_back(
6336 tensor::DimOp::create(
b, loc, source, i).getResult());
6338 mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
6340 if (!outerDimsPerm.empty()) {
6345 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6346 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6348 auto elemType = srcType.getElementType();
6349 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
6352UnPackOp UnPackOp::createTransposedClone(OpBuilder &
b, Location loc,
6353 Value transposedSource,
6354 ArrayRef<int64_t> innerPermutation,
6355 ArrayRef<int64_t> outerPermutation) {
6357 *
this, innerPermutation, outerPermutation);
6358 return UnPackOp::create(
b, loc, transposedSource, getDest(),
6359 metadata.innerDimsPos, metadata.innerTiles,
6360 metadata.outerDimsPerm);
6367 bool changeNeeded =
false;
6368 srcShape.assign(op.getSourceType().getShape().begin(),
6369 op.getSourceType().getShape().end());
6370 destShape.assign(op.getDestType().getShape().begin(),
6371 op.getDestType().getShape().end());
6372 llvm::SmallSetVector<int64_t, 4> innerDims;
6373 innerDims.insert_range(op.getInnerDimsPos());
6375 if (!op.getOuterDimsPerm().empty())
6377 int destRank = op.getDestRank();
6378 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
6379 if (innerDims.contains(i))
6383 if (!inverseOuterDimsPerm.empty())
6384 srcPos = inverseOuterDimsPerm[destPos];
6385 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6386 ShapedType::isDynamic(destShape[destPos])) {
6389 int64_t size = srcShape[srcPos];
6390 if (ShapedType::isDynamic(size))
6391 size = destShape[destPos];
6392 srcShape[srcPos] = size;
6393 destShape[destPos] = size;
6394 changeNeeded =
true;
6396 return changeNeeded;
6399LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6402 if (!unPackOp.hasPureTensorSemantics())
6406 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6407 if (packOp.getSourceType() != unPackOp.getDestType())
6409 if (packOp.getPaddingValue() ||
6413 rewriter.
replaceOp(unPackOp, packOp.getSource());
6417 if (
auto dstStyleOp =
6418 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6419 auto destValue = cast<OpResult>(unPackOp.getDest());
6420 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6422 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6426 if (unPackOp->hasOneUse()) {
6427 auto extractSliceUser =
6428 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6429 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6430 OpBuilder::InsertionGuard g(rewriter);
6432 auto newDest = tensor::ExtractSliceOp::create(
6433 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6434 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6435 extractSliceUser.getMixedStrides());
6437 unPackOp.setDpsInitOperand(0, newDest);
6438 unPackOp.getResult().setType(newDest.
getType());
6440 rewriter.
replaceOp(extractSliceUser, unPackOp);
6446 SmallVector<int64_t> srcShape, destShape;
6448 Location loc = unPackOp.getLoc();
6449 Value source = unPackOp.getSource();
6450 if (srcShape != unPackOp.getSourceType().getShape()) {
6451 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6452 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6453 unPackOp.getSource());
6455 Value dest = unPackOp.getDest();
6456 if (destShape != unPackOp.getDestType().getShape()) {
6457 auto newDestType = unPackOp.getDestType().clone(destShape);
6458 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6459 unPackOp.getDest());
6461 UnPackOp newOp = UnPackOp::create(
6462 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6463 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6465 unPackOp, unPackOp.getResult().
getType(), newOp.getResult());
6472bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6474 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6479 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6480 SmallVector<int64_t> outerShapeWithoutTranspose =
6482 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(),
false);
6483 for (
auto [pos, tileSize] :
6484 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6485 areOuterDimsTiled[pos] =
true;
6486 if (unpackedTypeAfterFold.isDynamicDim(pos))
6488 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6490 if (ShapedType::isDynamic(tileSize))
6492 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6493 unpackedTypeAfterFold.getDimSize(pos);
6494 if (paddingSize >= tileSize)
6498 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6499 if (areOuterDimsTiled[pos])
6501 int64_t dim = outerShapeWithoutTranspose[pos];
6502 if (ShapedType::isDynamic(dim))
6504 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6510bool UnPackOp::isLikeUnPad() {
6511 ShapedType packedTensorType = getSourceType();
6515::mlir::LogicalResult
6516UnPackOp::fold(FoldAdaptor adaptor,
6517 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6519 if (!hasPureTensorSemantics())
6522 if (OpFoldResult reshapedSource = reshapeConstantSource(
6523 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6524 cast<TensorType>(getResult().
getType()))) {
6525 results.push_back(reshapedSource);
6551 if (!op.hasPureTensorSemantics())
6560 Value sourceTensor = newOperands[0];
6564 rewriter, sourceTensor.
getType(), op.getMixedTiles());
6570 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6571 newOperands[1], op.getInnerDimsPos(),
6572 newMixedTileSizes, op.getOuterDimsPerm());
6573 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6576 Value oldResult = op.getResult();
6577 Value newResult = newOp.getResult();
6580 ? tensor::CastOp::create(rewriter, op->getLoc(),
6581 oldResult.
getType(), newResult)
6595 utils::IteratorType::reduction, utils::IteratorType::parallel,
6596 utils::IteratorType::parallel, utils::IteratorType::reduction};
6599SmallVector<AffineMap>
6600BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6601 AffineExpr d0, d1, d2, d3;
6602 SmallVector<AffineMap> indexingMaps;
6604 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
6605 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
6607 return indexingMaps;
6610bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6611 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6614 if (maps.size() != 3)
6619 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6620 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6621 (*positions)[2] == SmallVector<int64_t>{1, 2};
/// Returns the number of block arguments the op's region is built with.
/// Always 3, matching the count `regionBuilder` reports in its
/// "expects 3 args" diagnostic; `parse` forwards this value to
/// `parseNamedStructuredOp`.
6623unsigned BatchReduceMatmulOp::getNumRegionArgs() {
return 3; }
6625std::string BatchReduceMatmulOp::getLibraryCallName() {
6631bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6632 SmallVector<AffineMap, 3> defaultMaps =
6634 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6635 return defaultMaps != explicitMaps;
6645bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6648 "Expected less than 3 result dim expr.");
6649 bool isValid =
false;
6650 enum Indices { batchPos, mPos, nPos, kPos };
6652 AffineExpr expr = bcastMap.
getResult(0);
6655 AffineExpr expr0 = bcastMap.
getResult(0);
6656 AffineExpr expr1 = bcastMap.
getResult(1);
6661 : ((expr0.isFunctionOfDim(batchPos) &&
6662 expr1.isFunctionOfDim(kPos)) ||
6663 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6668void BatchReduceMatmulOp::regionBuilder(
6669 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
6672 emitError() <<
"BatchReduceMatmulOp regionBuilder expects 3 args, got "
6677 "BatchReduceMatmulOp regionBuilder expects 3 args");
6678 RegionBuilderHelper helper(
b, block);
6679 SmallVector<Value> yields;
6683 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
6685 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
6687 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
6688 if (!castValA || !castValB || !mulVal)
6691 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
6694 yields.push_back(addVal);
6695 helper.yieldOutputs(yields);
6698ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6699 OperationState &
result) {
6700 SmallVector<Attribute, 3> indexingMapsAttr;
6711 if (!isa<AffineMapAttr>(mapAttr)) {
6713 "expected affine map attribute");
6715 indexingMapsAttr.push_back(mapAttr);
6725 if (indexingMapsAttr.empty()) {
6726 indexingMapsAttr = llvm::map_to_vector(
6727 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
6728 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6730 result.addAttribute(
"indexing_maps",
6732 return ::parseNamedStructuredOp(parser,
result,
6733 BatchReduceMatmulOp::getNumRegionArgs(),
6734 BatchReduceMatmulOp::getRegionBuilder());
6737void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6738 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6739 BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
6740 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
6742 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6743 p <<
" indexing_maps = [";
6744 llvm::interleaveComma(getIndexingMaps(), p,
6749 SmallVector<StringRef, 3> elidedAttrs = {
6750 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
6756LogicalResult BatchReduceMatmulOp::verify() {
6759 if (!hasUserDefinedMaps())
6762 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
6768LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6769 SmallVectorImpl<OpFoldResult> &) {
6772void BatchReduceMatmulOp::getEffects(
6773 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6775 if (hasPureTensorSemantics())
6791void LinalgDialect::getCanonicalizationPatterns(
6800 return arith::ConstantOp::materialize(builder, value, type, loc);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static std::optional< TypedAttr > getScalarConstantAttrFromDenseSplat(Value input)
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map. Affine maps are immutable, like Types, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, ArrayRef< OpFoldResult > mixedTiles)
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name-mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
Rewrite a broadcast of a dense splat constant into a dense splat constant of the broadcast output sha...
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Rewrite a transpose of a dense splat constant into a dense splat constant of the transposed output sh...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override