40#include "llvm/ADT/DenseMap.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SetOperations.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/SmallVectorExtras.h"
45#include "llvm/ADT/StringSet.h"
46#include "llvm/ADT/TypeSwitch.h"
47#include "llvm/Support/FormatVariadic.h"
48#include "llvm/Support/InterleavedRange.h"
49#include "llvm/Support/LogicalResult.h"
50#include "llvm/Support/MathExtras.h"
51#include "llvm/Support/raw_ostream.h"
61 auto type = cast<ShapedType>(v.
getType());
62 if (!type.isDynamicDim(dim))
67 .Case([&](RankedTensorType t) ->
Value {
68 return tensor::DimOp::create(builder, loc, v, dim);
70 .Case([&](MemRefType t) ->
Value {
71 return memref::DimOp::create(builder, loc, v, dim);
82 .Case([&](RankedTensorType t) ->
Operation * {
83 return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
86 .Case([&](MemRefType type) ->
Operation * {
87 return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
99 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
100 return b.createOrFold<memref::DimOp>(loc, source, dim);
101 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
102 return b.createOrFold<tensor::DimOp>(loc, source, dim);
103 llvm_unreachable(
"Expected MemRefType or TensorType");
108 auto shapedType = llvm::cast<ShapedType>(source.
getType());
109 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
111 return b.getIndexAttr(shapedType.getDimSize(dim));
134 for (
auto containers : {inputTypes, outputTypes}) {
135 for (
auto t : containers) {
147 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
163 std::optional<TypeRange> resultTensorTypes,
170 if (!resultTensorTypes)
171 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
172 llvm::IsaPred<RankedTensorType>);
180 "operandSegmentSizes",
181 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
182 static_cast<int32_t>(outputs.size())}));
192 std::optional<TypeRange> resultTensorTypes,
199 return attr.
getName() ==
"indexing_maps";
202 indexingMapsAttrVal = llvm::map_to_vector(
205 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
208 attributes, regionBuilder);
212 std::optional<TypeRange> resultTensorTypes,
219 return attr.
getName() ==
"indexing_maps";
222 indexingMapsAttrVal = llvm::map_to_vector(
225 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
228 attributes, regionBuilder);
232 std::optional<TypeRange> resultTensorTypes,
239 indexingMapsAttrVal =
241 return AffineMapAttr::get(map);
243 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
245 attributes, regionBuilder);
254 bool addOperandSegmentSizes =
true) {
255 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
284 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
286 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
290 if (addOperandSegmentSizes) {
297 if (
result.propertiesAttr) {
299 attrs.
append(
"operandSegmentSizes",
301 {static_cast<int32_t>(inputsOperands.size()),
302 static_cast<int32_t>(outputsOperands.size())}));
305 result.addAttribute(
"operandSegmentSizes",
307 {static_cast<int32_t>(inputsOperands.size()),
308 static_cast<int32_t>(outputsOperands.size())}));
311 if (!
result.propertiesAttr) {
312 std::optional<RegisteredOperationName> info =
313 result.name.getRegisteredInfo();
315 if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
316 return parser.emitError(attrsLoc)
317 <<
"'" << result.name.getStringRef() <<
"' op ";
328 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
329 if (!outputs.empty())
330 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
341 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
344 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
345 "region expects {0} args, got {1}",
346 numRegionArgs, inputTypes.size() + outputTypes.size()));
352 opBuilder, region, inputTypes, outputTypes, attrs,
371 unsigned numRegionArgs,
388 result.addTypes(outputTensorsTypes);
390 std::unique_ptr<Region> region = std::make_unique<Region>();
392 outputTypes,
result.attributes.getAttrs(),
395 result.addRegion(std::move(region));
402 if (resultTypes.empty())
447class RegionBuilderHelper {
449 RegionBuilderHelper(OpBuilder &builder,
Block &block)
450 : builder(builder), block(block) {}
453 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
455 if (!isFloatingPoint(arg)) {
457 emitError() <<
"unsupported non numeric type";
460 llvm_unreachable(
"unsupported non numeric type");
462 OpBuilder::InsertionGuard g(builder);
463 builder.setInsertionPointToEnd(&block);
466 return math::ExpOp::create(builder, arg.
getLoc(), arg);
468 return math::LogOp::create(builder, arg.
getLoc(), arg);
470 return math::AbsFOp::create(builder, arg.
getLoc(), arg);
472 return math::CeilOp::create(builder, arg.
getLoc(), arg);
474 return math::FloorOp::create(builder, arg.
getLoc(), arg);
476 return arith::NegFOp::create(builder, arg.
getLoc(), arg);
477 case UnaryFn::reciprocal: {
478 Attribute oneAttr = builder.getOneAttr(arg.
getType());
479 auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
480 ::cast<TypedAttr>(oneAttr));
481 return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
484 return math::RoundOp::create(builder, arg.
getLoc(), arg);
486 return math::SqrtOp::create(builder, arg.
getLoc(), arg);
488 return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
489 case UnaryFn::square:
490 return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
492 return math::TanhOp::create(builder, arg.
getLoc(), arg);
494 return math::ErfOp::create(builder, arg.
getLoc(), arg);
497 emitError() <<
"unsupported unary function";
500 llvm_unreachable(
"unsupported unary function");
507 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
509 bool allComplex = isComplex(arg0) && isComplex(arg1);
510 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
511 bool allInteger = isInteger(arg0) && isInteger(arg1);
514 if (!allComplex && !allFloatingPoint && !allInteger) {
517 <<
"Cannot build binary Linalg operation: expects allComplex, "
518 "allFloatingPoint, or allInteger, got "
522 llvm_unreachable(
"unsupported non numeric type");
524 OpBuilder::InsertionGuard g(builder);
525 builder.setInsertionPointToEnd(&block);
529 return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
530 if (allFloatingPoint)
531 return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
533 return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
534 return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
537 return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
538 if (allFloatingPoint)
539 return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
542 emitError() <<
"unsupported operation: sub with bools";
545 llvm_unreachable(
"unsupported operation: sub with bools");
547 return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
550 return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
551 if (allFloatingPoint)
552 return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
554 return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
555 return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
558 return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
559 if (allFloatingPoint)
560 return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
563 emitError() <<
"unsupported operation: div with bools";
566 llvm_unreachable(
"unsupported operation: div with bools");
568 return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
569 case BinaryFn::div_unsigned:
570 if (!allInteger || allBool) {
572 emitError() <<
"unsupported operation: unsigned div not on uint";
575 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
577 return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
578 case BinaryFn::max_signed:
580 if (allFloatingPoint)
581 return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
582 return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
583 case BinaryFn::min_signed:
585 if (allFloatingPoint)
586 return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
587 return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
588 case BinaryFn::max_unsigned:
590 if (!allInteger || allBool) {
592 emitError() <<
"unsupported operation: unsigned max not on uint";
595 llvm_unreachable(
"unsupported operation: unsigned max not on uint");
597 return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
598 case BinaryFn::min_unsigned:
600 if (!allInteger || allBool) {
602 emitError() <<
"unsupported operation: unsigned min not on uint";
605 llvm_unreachable(
"unsupported operation: unsigned min not on uint");
607 return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
609 assert(allFloatingPoint);
610 return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
613 emitError() <<
"unsupported binary function";
616 llvm_unreachable(
"unsupported binary function");
620 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
622 OpBuilder::InsertionGuard g(builder);
623 builder.setInsertionPointToEnd(&block);
625 case TernaryFn::select:
626 return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
629 emitError() <<
"unsupported ternary function";
632 llvm_unreachable(
"unsupported ternary function");
636 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
639 case TypeFn::cast_signed:
640 return cast(toType, operand,
false);
641 case TypeFn::cast_unsigned:
642 return cast(toType, operand,
true);
645 emitError() <<
"unsupported type conversion function";
648 llvm_unreachable(
"unsupported type conversion function");
652 OpBuilder::InsertionGuard g(builder);
653 builder.setInsertionPointToEnd(&block);
654 Location loc = builder.getUnknownLoc();
655 YieldOp::create(builder, loc, values);
658 Value constant(
const std::string &value) {
659 OpBuilder::InsertionGuard g(builder);
660 builder.setInsertionPointToEnd(&block);
661 Location loc = builder.getUnknownLoc();
662 Attribute valueAttr =
parseAttribute(value, builder.getContext());
663 return arith::ConstantOp::create(builder, loc,
664 ::cast<TypedAttr>(valueAttr));
667 Value index(int64_t dim) {
668 OpBuilder::InsertionGuard g(builder);
669 builder.setInsertionPointToEnd(&block);
670 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
673 Type getIntegerType(
unsigned width) {
674 return IntegerType::get(builder.getContext(), width);
677 Type getFloat32Type() {
return Float32Type::get(builder.getContext()); }
678 Type getFloat64Type() {
return Float64Type::get(builder.getContext()); }
685 Value cast(Type toType, Value operand,
bool isUnsignedCast) {
686 OpBuilder::InsertionGuard g(builder);
687 builder.setInsertionPointToEnd(&block);
688 auto loc = operand.
getLoc();
689 if (isa<UnknownLoc>(loc)) {
699 bool isComplex(Value value) {
700 return llvm::isa<ComplexType>(value.
getType());
702 bool isFloatingPoint(Value value) {
703 return llvm::isa<FloatType>(value.
getType());
705 bool isInteger(Value value) {
706 return llvm::isa<IntegerType>(value.
getType());
722 using OpRewritePattern<CopyOp>::OpRewritePattern;
723 LogicalResult matchAndRewrite(CopyOp copyOp,
724 PatternRewriter &rewriter)
const override {
725 if (copyOp.getInputs() != copyOp.getOutputs())
727 if (copyOp.hasPureBufferSemantics())
730 rewriter.
replaceOp(copyOp, copyOp.getInputs());
740 results.
add<EraseSelfCopy>(context);
753template <
typename TensorReshapeOp>
754struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
755 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
756 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
757 PatternRewriter &rewriter)
const override {
758 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
762 Location loc = oldFill.getLoc();
763 TensorReshapeOp newInit;
764 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
766 newInit = TensorReshapeOp::create(
767 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
768 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
769 reshapeOp.getStaticOutputShape());
771 newInit = TensorReshapeOp::create(
772 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
773 reshapeOp.getReassociation());
783struct FoldFillWithPad final :
public OpRewritePattern<tensor::PadOp> {
786 LogicalResult matchAndRewrite(tensor::PadOp padOp,
787 PatternRewriter &rewriter)
const override {
788 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
794 Value padValue = padOp.getConstantPaddingValue();
795 if (!padValue || fillOp.value() != padValue)
801 padOp,
"failed to reify tensor.pad op result shape");
804 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
805 padOp.getResultType().getElementType());
807 FillOp::create(rewriter, fillOp.getLoc(),
ValueRange{padValue},
810 if (
replacement.getType() != padOp.getResultType()) {
811 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
822struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
825 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
826 PatternRewriter &rewriter)
const override {
827 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
831 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
836 Value firstDest = insertOp.getDest();
837 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
838 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
843 bool disjoint =
false;
844 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
847 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
848 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
849 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
853 int64_t prevStart = prevOp.getStaticOffset(i);
854 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
855 prevOp.getStaticStride(i);
856 int64_t nextStart = insertOp.getStaticOffset(i);
857 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
858 insertOp.getStaticStride(i);
859 if (prevEnd < nextStart || nextEnd < prevStart) {
867 firstDest = prevOp.getDest();
878 Value padValue = srcPadOp.getConstantPaddingValue();
879 if (!padValue || dstFillOp.value() != padValue)
882 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
883 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
885 Location loc = insertOp.getLoc();
888 AffineExpr sym0, sym1;
894 SmallVector<OpFoldResult, 4> newOffsets;
895 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
897 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
900 RankedTensorType srcPadType = srcPadOp.getSourceType();
901 SmallVector<OpFoldResult, 4> newSizes;
902 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
903 if (srcPadType.isDynamicDim(i)) {
905 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
908 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
913 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
914 newSizes, insertOp.getMixedStrides());
920struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
922 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
924 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
925 PatternRewriter &rewriter)
const override {
928 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
933 Value extractedScalar = fillOp.getInputs()[0];
936 rewriter.
replaceOp(extractOp, extractedScalar);
944static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
945 linalg::PackOp packOp) {
946 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
950 if (
auto paddingValue = packOp.getPaddingValue())
954 Value packOpDest = packOp.getDest();
958 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
963struct FoldFillWithPack :
public OpRewritePattern<linalg::PackOp> {
965 FoldFillWithPack(MLIRContext *context)
966 : OpRewritePattern<linalg::PackOp>(context) {}
968 LogicalResult matchAndRewrite(linalg::PackOp packOp,
969 PatternRewriter &rewriter)
const override {
970 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
973 rewriter.
replaceOp(packOp, fillOp.value().result());
979struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
980 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
982 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
983 PatternRewriter &rewriter)
const override {
984 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
987 copyOp.getOutputs());
990 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
992 fillOp.getOutputs());
1000struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1001 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1003 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1004 PatternRewriter &rewriter)
const override {
1005 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1007 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1008 transposeOp.getDpsInitOperand(0)->get());
1017struct FoldConcatsOfFill :
public OpRewritePattern<tensor::ConcatOp> {
1020 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1021 PatternRewriter &rewriter)
const override {
1022 auto concatOperands = concatOp.getInputs();
1023 if (concatOperands.empty()) {
1027 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1032 OpFoldResult firstFillVal =
1035 SmallVector<Value> allOuts;
1036 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1038 auto isDefinedByCompatibleFillOp = [&](Value v) ->
bool {
1039 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1044 OpFoldResult fillVal =
1046 if (fillVal != firstFillVal)
1049 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1052 if (!llvm::all_of(concatOperands.drop_front(),
1053 isDefinedByCompatibleFillOp)) {
1055 concatOp,
"not all operands are defined by a compatible fill op");
1058 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1059 concatOp.getDim(), allOuts);
1061 concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
1068void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
1069 MLIRContext *context) {
1070 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1071 FoldFillWithPack, FoldFillWithPad,
1072 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1073 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1074 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1087 for (
ValueRange container : {inputs, outputs}) {
1088 for (
Value v : container) {
1089 Type t = v.getType();
1090 blockArgTypes.push_back(
1092 blockArgLocs.push_back(v.getLoc());
1098 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1102void GenericOp::getAsmBlockArgumentNames(Region ®ion,
1104 for (Value v : getRegionInputArgs())
1106 for (Value v : getRegionOutputArgs())
1107 setNameFn(v,
"out");
1110void GenericOp::build(
1111 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1113 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1115 ArrayRef<NamedAttribute> attributes) {
1116 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1117 iteratorTypes, doc, libraryCall);
1118 result.addAttributes(attributes);
1121 inputs, outputs, bodyBuild);
1124void GenericOp::build(
1125 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1127 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1128 StringRef libraryCall,
1130 ArrayRef<NamedAttribute> attributes) {
1131 build(builder,
result, resultTensorTypes, inputs, outputs,
1135 [&](utils::IteratorType iter) -> mlir::Attribute {
1136 return IteratorTypeAttr::get(builder.getContext(), iter);
1139 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1140 bodyBuild, attributes);
1143void GenericOp::build(
1145 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1146 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1147 StringRef libraryCall,
1149 ArrayRef<NamedAttribute> attributes) {
1151 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1154void GenericOp::build(
1156 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1157 ArrayRef<utils::IteratorType> iteratorTypes,
1159 ArrayRef<NamedAttribute> attributes) {
1160 build(builder,
result, inputs, outputs, indexingMaps, iteratorTypes,
1162 "", bodyBuild, attributes);
1165void GenericOp::build(
1166 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1168 ArrayRef<utils::IteratorType> iteratorTypes,
1170 ArrayRef<NamedAttribute> attributes) {
1171 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1174 "", bodyBuild, attributes);
1177void GenericOp::print(OpAsmPrinter &p) {
1181 auto genericAttrNames = linalgTraitAttrNames();
1183 llvm::StringSet<> genericAttrNamesSet;
1184 genericAttrNamesSet.insert_range(genericAttrNames);
1185 SmallVector<NamedAttribute, 8> genericAttrs;
1186 for (
auto attr : (*this)->getAttrs()) {
1187 if (attr.getName() == getIteratorTypesAttrName()) {
1188 auto iteratorTypes =
1189 llvm::cast<ArrayAttr>(attr.getValue())
1190 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1195 SmallVector<Attribute> iteratorTypeNames = llvm::map_to_vector(
1196 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1197 return StringAttr::get(
getContext(), stringifyIteratorType(t));
1200 genericAttrs.emplace_back(
1201 getIteratorTypesAttrName(),
1202 ArrayAttr::get(
getContext(), iteratorTypeNames));
1203 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1204 genericAttrs.push_back(attr);
1207 if (!genericAttrs.empty()) {
1208 auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
1209 p << genericDictAttr;
1215 genericAttrNames.push_back(
"operandSegmentSizes");
1216 genericAttrNamesSet.insert(genericAttrNames.back());
1218 bool hasExtraAttrs =
false;
1219 for (NamedAttribute n : (*this)->getAttrs()) {
1220 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1223 if (hasExtraAttrs) {
1230 if (!getRegion().empty()) {
1239ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
1240 DictionaryAttr dictAttr;
1248 result.attributes.assign(dictAttr.getValue().begin(),
1249 dictAttr.getValue().end());
1255 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1256 result.attributes.get(getIteratorTypesAttrName(
result.name)));
1257 if (!iteratorTypes) {
1258 return parser.
emitError(attributeLocation)
1259 <<
"expected " << getIteratorTypesAttrName(
result.name)
1260 <<
" array attribute";
1263 SmallVector<Attribute> iteratorTypeAttrs;
1265 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1266 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1267 if (!maybeIteratorType.has_value())
1269 <<
"unexpected iterator_type (" << s <<
")";
1271 iteratorTypeAttrs.push_back(
1272 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
1274 result.attributes.set(getIteratorTypesAttrName(
result.name),
1278 SmallVector<Type, 1> inputTypes, outputTypes;
1288 std::unique_ptr<Region> region = std::make_unique<Region>();
1291 result.addRegion(std::move(region));
1297 SmallVector<Type, 1> outputTensorsTypes;
1300 result.addTypes(outputTensorsTypes);
1308 LinalgOp linalgOp) {
1309 for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1310 if (!llvm::isa<MemRefType>(operand.
getType()))
1312 effects.emplace_back(
1317 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1318 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1320 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1331void GenericOp::getEffects(
1332 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1341 if (!linalgOp.hasPureTensorSemantics())
1359template <
typename OpTy>
1360struct EraseIdentityLinalgOp :
public OpRewritePattern<OpTy> {
1361 using OpRewritePattern<OpTy>::OpRewritePattern;
1363 LogicalResult matchAndRewrite(OpTy linalgOp,
1364 PatternRewriter &rewriter)
const override {
1366 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1371 Block &body = linalgOp->getRegion(0).front();
1372 if (!llvm::hasSingleElement(body))
1374 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1379 if (linalgOp.hasPureBufferSemantics()) {
1380 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1381 linalgOp.getDpsInputOperand(0)->get() !=
1382 linalgOp.getDpsInitOperand(0)->get()) {
1384 linalgOp,
"expected single input and output to be the same value");
1387 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1388 if (!yieldArg || yieldArg.getOwner() != &body) {
1390 "cannot fold fill-like op");
1397 if (!linalgOp.hasPureTensorSemantics()) {
1399 linalgOp,
"mixed semantics is not supported yet");
1404 SmallVector<Value> returnedArgs;
1405 for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1406 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1407 if (!yieldArg || yieldArg.getOwner() != &body)
1409 unsigned argumentNumber = yieldArg.getArgNumber();
1410 Value returnedArg = linalgOp->getOperand(argumentNumber);
1411 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1414 Type returnType = returnedArg.
getType();
1415 if (returnType != resultType) {
1420 returnedArg = sparse_tensor::ConvertOp::create(
1421 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1423 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1426 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1427 resultType, returnedArg);
1430 returnedArgs.push_back(returnedArg);
1433 if (returnedArgs.size() != linalgOp->getNumResults())
1435 rewriter.
replaceOp(linalgOp, returnedArgs);
1442void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1443 MLIRContext *context) {
1444 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1447LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1466 for (
Type outputType : outputTypes) {
1467 if (llvm::isa<RankedTensorType>(outputType))
1468 result.addTypes(outputType);
1472 if (parseAttrsFn && failed(parseAttrsFn(parser,
result.attributes)))
1481void MapOp::getAsmBlockArgumentNames(Region ®ion,
1483 for (Value v : getRegionInputArgs())
1485 for (Value v : getRegionOutputArgs())
1486 setNameFn(v,
"init");
1489void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1490 if (!getResults().empty())
1491 setNameFn(getResults().front(),
"mapped");
1497 ArrayRef<NamedAttribute> attributes) {
1499 result.addAttributes(attributes);
1502 Type initType = init.
getType();
1503 if (llvm::isa<RankedTensorType>(initType))
1504 result.addTypes(initType);
1508 inputs, {init}, bodyBuild);
1515 bool initFirst =
false,
bool mapInit =
true) {
1519 b.setInsertionPointToStart(&block);
1520 for (
auto &operand : operands) {
1522 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1530 payloadOpOperands.push_back(block.
getArguments().back());
1531 for (
const auto &arg : block.
getArguments().drop_back())
1532 payloadOpOperands.push_back(arg);
1541 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1547ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
1548 std::optional<OperationName> payloadOpName;
1549 NamedAttrList payloadOpAttrs;
1552 if (
failed(operationName))
1556 payloadOpName = operationName.value();
1564 if (payloadOpName.has_value()) {
1565 if (!
result.operands.empty())
1567 payloadOpAttrs, ArrayRef(
result.operands),
false,
1572 SmallVector<OpAsmParser::Argument> regionArgs;
1577 Region *body =
result.addRegion();
1585 bool mapInit =
true) {
1587 if (initFirst && !mapInit)
1611 for (
const auto &[operand, bbArg] :
1613 if (bbArg != operand)
1617 for (
const auto &[operand, bbArg] :
1620 if (bbArg != operand)
1627 return yieldOp.getNumOperands() == 1 &&
1628 yieldOp.getOperand(0).getDefiningOp() &&
1629 yieldOp.getOperand(0).getDefiningOp() == &payload;
1634 std::string attrToElide;
1636 for (
const auto &attr : payloadOp->
getAttrs()) {
1638 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1639 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1640 attrToElide = attr.getName().str();
1641 elidedAttrs.push_back(attrToElide);
1649void MapOp::print(OpAsmPrinter &p) {
1650 Block *mapper = getBody();
1660 if (!useShortForm) {
1666 [&](
auto arg) { p.printRegionArgument(arg); });
1674LogicalResult MapOp::verify() {
1675 auto *bodyBlock = getBody();
1676 auto blockArgs = bodyBlock->getArguments();
1680 if (getInputs().size() + 1 != blockArgs.size())
1681 return emitOpError() <<
"expects number of operands to match the arity of "
1683 << getInputs().size() + 1 <<
" and "
1684 << blockArgs.size();
1687 for (
const auto &[bbArgType, inputArg] :
1688 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1689 auto inputElemType =
1690 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1691 if (bbArgType != inputElemType) {
1692 return emitOpError() <<
"expected element type of input " << inputElemType
1693 <<
" to match bbArg type " << bbArgType;
1698 auto outputShape = getInit().getType().getShape();
1699 for (Type inputArgType :
TypeRange{getInputs()}) {
1700 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1701 if (inputElemShape != outputShape) {
1702 return emitOpError() <<
"expected shape of input (" << inputElemShape
1703 <<
") to match shape of output (" << outputShape
1711SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1712 int64_t rank = getInit().getType().getRank();
1713 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1718 int64_t rank = getInit().getType().getRank();
1719 int64_t numIndexingMaps = getOperands().size();
1724void MapOp::getEffects(
1725 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1738void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1740 for (Value v : getRegionInputArgs())
1742 for (Value v : getRegionOutputArgs())
1743 setNameFn(v,
"init");
1746void ReduceOp::getAsmResultNames(
1748 if (!getResults().empty())
1749 setNameFn(getResults().front(),
"reduced");
1752void ReduceOp::build(
1754 ValueRange inits, ArrayRef<int64_t> dimensions,
1756 ArrayRef<NamedAttribute> attributes) {
1758 result.addAttributes(attributes);
1761 for (Value init : inits) {
1762 Type initType = init.
getType();
1763 if (llvm::isa<RankedTensorType>(initType))
1764 result.addTypes(initType);
1769 inputs, inits, bodyBuild);
1772SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1774 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1775 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1776 utils::IteratorType::parallel);
1777 for (int64_t reductionDim : getDimensions())
1778 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1779 return iteratorTypes;
1784 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1785 SmallVector<AffineMap> affineMaps(
1788 AffineMap resultMap =
1791 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1792 affineMaps.push_back(resultMap);
1793 return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
1796void ReduceOp::getEffects(
1797 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1808 StringRef attributeName) {
1816ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
1817 std::optional<OperationName> payloadOpName;
1818 NamedAttrList payloadOpAttrs;
1821 if (
failed(operationName))
1825 payloadOpName = operationName.value();
1831 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1836 if (payloadOpName.has_value()) {
1838 ArrayRef(
result.operands),
true);
1840 SmallVector<OpAsmParser::Argument> regionArgs;
1846 Region *body =
result.addRegion();
1856 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1859void ReduceOp::print(OpAsmPrinter &p) {
1860 Block *mapper = getBody();
1869 if (!useShortForm) {
1875 [&](
auto arg) { p.printRegionArgument(arg); });
1883LogicalResult ReduceOp::verify() {
1884 ArrayRef<int64_t> dimensionsRef = getDimensions();
1891 if (getInputs().size() !=
static_cast<size_t>(getNumDpsInputs()))
1893 <<
"expected equal number of inputs and outputs (required by "
1894 "SameVariadicOperandSize), got "
1895 << getNumDpsInputs() <<
" input(s) and " << getNumDpsInits()
1898 if (getInputs().empty())
1899 return emitOpError() <<
"expected at least one input";
1900 if (getInits().empty())
1901 return emitOpError() <<
"expected at least one output";
1903 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1906 return emitOpError() <<
"expects all inputs to have the same shapes. "
1907 "Shape at input-index "
1909 <<
" is not equal to the shape at input-index 0.";
1912 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1915 return emitOpError() <<
"expects all outputs to have the same shapes. "
1916 "Shape at output-index "
1918 <<
" is not equal to the shape at output-index 0.";
1921 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1922 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1925 for (int64_t dimension : dimensionsRef) {
1926 if (dimension < 0 || dimension >= inputType.getRank()) {
1928 <<
"dimensions for reduction should be in the range [0, "
1929 << inputType.getRank() - 1 <<
"].";
1931 dimensionsToReduce.insert(dimension);
1934 auto inputDims = inputType.getShape();
1935 auto initDims = initType.getShape();
1938 SmallVector<int64_t> reducedInputDims;
1939 for (
const auto &en : llvm::enumerate(inputDims)) {
1940 if (!dimensionsToReduce.count(en.index()))
1941 reducedInputDims.push_back(en.value());
1944 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1945 return emitOpError() <<
"number of dimensions after reduction "
1946 << reducedInputDims.size()
1947 <<
" doesn't match the init rank "
1948 << initType.getRank();
1951 if (reducedInputDims != initDims)
1952 return emitOpError() <<
"init dimensions [" << initDims
1953 <<
"] doesn't match input dimensions after reduction ["
1954 << reducedInputDims <<
"]";
1956 Block *block = getBody();
1959 <<
"mismatching number of operands and block arguments";
1962 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1963 Type inputElementType =
1964 llvm::cast<ShapedType>(input.getType()).getElementType();
1965 if (inputElementType != bbArg.getType())
1967 <<
"input element type " << inputElementType
1968 <<
" does not match corresponding block argument type "
1973 for (
auto [output, bbArg] : llvm::zip(
1974 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1975 auto outputElementType =
1976 llvm::cast<ShapedType>(output.getType()).getElementType();
1977 if (outputElementType != bbArg.getType())
1979 <<
"output element type " << outputElementType
1980 <<
" does not match corresponding block argument type "
1996 linalg::YieldOp::create(
b, loc, args[0]);
2000void TransposeOp::build(::mlir::OpBuilder &builder,
2001 ::mlir::OperationState &
result, Value input, Value init,
2003 ArrayRef<NamedAttribute> attributes) {
2004 result.addOperands(input);
2005 result.addOperands(init);
2006 result.addAttribute(getPermutationAttrName(
result.name), permutation);
2007 result.addAttributes(attributes);
2010 Type initType = init.
getType();
2011 if (llvm::isa<RankedTensorType>(initType))
2012 result.addTypes(initType);
2018void TransposeOp::build(::mlir::OpBuilder &builder,
2019 ::mlir::OperationState &
result, Value input, Value init,
2020 ArrayRef<int64_t> permutation,
2021 ArrayRef<NamedAttribute> attributes) {
2026ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
2028 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2040void TransposeOp::getAsmResultNames(
2042 if (!getResults().empty())
2043 setNameFn(getResults().front(),
"transposed");
2046void TransposeOp::print(OpAsmPrinter &p) {
2052LogicalResult TransposeOp::verify() {
2053 ArrayRef<int64_t> permutationRef = getPermutation();
2058 auto inputType = getInput().getType();
2059 auto initType = getInit().getType();
2061 int64_t rank = inputType.getRank();
2067 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
2068 return emitOpError() <<
"size of permutation " << permutationRef.size()
2069 <<
" does not match the argument rank " << rank;
2071 auto inputDims = inputType.getShape();
2072 auto initDims = initType.getShape();
2074 for (int64_t i = 0; i < rank; ++i) {
2075 int64_t inputDim = inputDims[permutationRef[i]];
2076 int64_t initDim = initDims[i];
2078 if (inputDim != initDim) {
2079 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
2080 <<
" doesn't match dim(input, permutation[" << i
2081 <<
"]) = " << inputDim;
2088SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2089 int64_t rank = getInit().getType().getRank();
2090 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2093ArrayAttr TransposeOp::getIndexingMaps() {
2095 int64_t rank = getInit().getType().getRank();
2098 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
2102void TransposeOp::getEffects(
2103 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2112LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2113 SmallVectorImpl<OpFoldResult> &
result) {
2115 if (!isa<TensorType>(getInput().
getType()))
2119 if (getPermutation().empty()) {
2120 result.push_back(getInput());
2125 result.push_back(getInput());
2138 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2139 if (!defTransposeOp)
2144 foldedPerms.reserve(perms.size());
2146 foldedPerms.push_back(defPerms[perm]);
2149 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2163 Value input = transposeOp.getInput();
2164 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2175 unsigned dimensionSize = dimensions.size();
2176 for (
unsigned i = 0; i < dimensionSize; ++i)
2177 resultDimensions.push_back(invertPerm[dimensions[i]]);
2180 Value broadcastInput = broadcastOp.getInput();
2181 Location loc = transposeOp.getLoc();
2184 auto broadcastInputTy =
2185 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2186 unsigned inputRank = broadcastInputTy.getRank();
2187 for (
unsigned i = 0; i < inputRank; ++i) {
2188 if (broadcastInputTy.isDynamicDim(i)) {
2189 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2192 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2193 broadcastInputTy.getDimSize(i)));
2198 Value transposeInit = tensor::EmptyOp::create(
2199 rewriter, transposeOp.getLoc(), transposeResultShapes,
2200 broadcastInputTy.getElementType());
2203 Value transposeResult =
2204 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2205 transposeInit, resultPerms)
2208 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2213void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2214 MLIRContext *context) {
2215 results.
add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2222void BroadcastOp::build(::mlir::OpBuilder &builder,
2223 ::mlir::OperationState &
result, Value input, Value init,
2225 ArrayRef<NamedAttribute> attributes) {
2226 result.addOperands(input);
2227 result.addOperands(init);
2228 result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
2229 result.addAttributes(attributes);
2232 Type initType = init.
getType();
2233 if (llvm::isa<RankedTensorType>(initType))
2234 result.addTypes(initType);
2240void BroadcastOp::build(::mlir::OpBuilder &builder,
2241 ::mlir::OperationState &
result, Value input, Value init,
2242 ArrayRef<int64_t> dimensions,
2243 ArrayRef<NamedAttribute> attributes) {
2248ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
2250 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2262void BroadcastOp::getAsmResultNames(
2264 if (!getResults().empty())
2265 setNameFn(getResults().front(),
"broadcasted");
2268void BroadcastOp::print(OpAsmPrinter &p) {
2274LogicalResult BroadcastOp::verify() {
2275 ArrayRef<int64_t> dimensionsRef = getDimensions();
2277 auto inputType = getInput().getType();
2278 auto initType = getInit().getType();
2280 int64_t inputRank = inputType.getRank();
2281 int64_t initRank = initType.getRank();
2283 auto inputShape = inputType.getShape();
2284 auto initShape = initType.getShape();
2286 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2287 return emitOpError() <<
"input rank plus added dimensions does not "
2288 "match init rank. input rank: "
2290 <<
", dimensions size: " << dimensionsRef.size()
2291 <<
", init rank: " << initRank;
2293 for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2294 if (dim < 0 || dim >= initRank)
2296 <<
" is out of range. expected range: [0, "
2297 << initRank - 1 <<
"], got: " << dim;
2301 SmallVector<int64_t> dimMap;
2302 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2303 if (!llvm::is_contained(dimensionsRef, dim))
2304 dimMap.push_back(dim);
2307 for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2310 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2311 return emitOpError() <<
"input dim " << inputDimIdx
2312 <<
" should match init dim " << initDimIdx
2313 <<
". input: " << inputShape[inputDimIdx]
2314 <<
", init: " << initShape[initDimIdx];
2320SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2321 int64_t rank = getInit().getType().getRank();
2322 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2325ArrayAttr BroadcastOp::getIndexingMaps() {
2327 int64_t rank = getInit().getType().getRank();
2333void BroadcastOp::getEffects(
2334 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2349 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2350 if (!defBroadcastOp)
2355 Value init = broadcastOp.getInit();
2359 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2360 if (!llvm::is_contained(dimensions, dim))
2361 dimMap.push_back(dim);
2363 for (
auto dim : defDimensions)
2364 foldedDims.push_back(dimMap[dim]);
2366 llvm::sort(foldedDims);
2368 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2373void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2374 MLIRContext *context) {
2375 results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2382void linalg::YieldOp::print(OpAsmPrinter &p) {
2383 if (getNumOperands() > 0)
2384 p <<
' ' << getOperands();
2386 if (getNumOperands() > 0)
2387 p <<
" : " << getOperandTypes();
2390ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
2391 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2392 SmallVector<Type, 2> types;
2402static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2403 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2404 return op.emitOpError(
"expected number of yield values (")
2405 << op.getNumOperands()
2406 <<
") to match the number of inits / outs operands of the enclosing "
2407 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2409 for (
OpOperand &opOperand : op->getOpOperands()) {
2411 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2413 if (isa<MemRefType, RankedTensorType>(elementType))
2415 if (opOperand.get().getType() != elementType)
2416 return op.emitOpError(
"type of yield operand ")
2417 << (opOperand.getOperandNumber() + 1) <<
" ("
2418 << opOperand.get().getType() <<
") doesn't match "
2419 <<
"the element type of the enclosing linalg.generic op ("
2420 << elementType <<
")";
2425LogicalResult linalg::YieldOp::verify() {
2426 auto *parentOp = (*this)->getParentOp();
2427 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2428 return emitOpError(
"expected single non-empty parent region");
2430 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2433 return emitOpError(
"expected parent op with LinalgOp interface");
2440LogicalResult IndexOp::verify() {
2441 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2443 return emitOpError(
"expected parent op with LinalgOp interface");
2444 if (linalgOp.getNumLoops() <= getDim())
2446 << getDim() <<
") to be lower than the number of loops ("
2447 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2451OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2452 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2457 return OpFoldResult{};
2460 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2461 uint64_t dim = getDim();
2462 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
2463 if (loopBounds[dim] == 1)
2464 return IntegerAttr::get(IndexType::get(
getContext()), 0);
2466 return OpFoldResult{};
2471#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2473#define GET_OP_CLASSES
2474#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2476#define GET_OP_CLASSES
2477#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2478#define GET_OP_CLASSES
2479#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2496 for (
unsigned i = 0; i < num; ++i)
2503 auto rangeA = llvm::make_range(a.begin(), a.end());
2504 auto rangeB = llvm::make_range(
b.begin(),
b.end());
2505 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2506 return llvm::to_vector<4>(concatRanges);
2510 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2512 for (
auto size :
memref.getShape())
2519 if (
auto as =
memref.getMemorySpace()) {
2520 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2521 ss <<
"as" << attr.getInt();
2527 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2530 vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2543 assert(isa<LinalgOp>(op));
2545 std::string fun =
"";
2547 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2548 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2549 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2550 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2554 llvm::replace(name,
'.',
'_');
2555 llvm::raw_string_ostream ss(name);
2559 return std::string();
2574 LogicalResult matchAndRewrite(LinalgOp op,
2576 for (
OpOperand &opOperand : op->getOpOperands()) {
2580 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2583 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2594struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2595 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2597 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2598 PatternRewriter &rewriter)
const override {
2602 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2609 if (castOp->getBlock() != linalgOp->getBlock())
2612 OpBuilder::InsertionGuard guard(rewriter);
2615 Location loc = linalgOp.getLoc();
2616 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2619 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2625 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2627 tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
2628 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2629 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2630 linalgOp.getDpsInits().end());
2631 outputOperands[resultNumber] = newOperand;
2632 newOperands.append(outputOperands.begin(), outputOperands.end());
2634 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2635 linalgOp->result_type_end());
2636 resultTypes[resultNumber] = resultType;
2637 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2640 Value castBack = tensor::CastOp::create(
2644 results[resultNumber] = castBack;
2653static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2654 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2655 for (OpOperand &opOperand : operands) {
2656 if (linalgOp.isScalar(&opOperand))
2658 Value src = opOperand.get();
2659 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2660 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2666 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2668 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2669 Value castSource = castOp.getSource();
2670 auto castSourceType =
2671 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2672 if (castSourceType && castSourceType.hasStaticShape())
2673 sourceShape = castSourceType.getShape();
2679 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2680 if (sourceType.isDynamicDim(i))
2682 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2683 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2693static void createNewOperandWithStaticSizes(
2694 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2695 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2696 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2697 bool &changeNeeded) {
2698 Value src = opOperand->
get();
2699 newOperands.push_back(src);
2700 if (linalgOp.isScalar(opOperand))
2702 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2703 Type resultType = sourceType;
2704 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2705 resultTypes.push_back(resultType);
2708 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2709 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2710 SmallVector<int64_t> newShape;
2713 bool newOperandNeeded =
false;
2714 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2715 int64_t dimShape = sourceShape[i];
2716 AffineExpr dimExpr = sourceMap.
getResult(i);
2717 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2718 newShape.push_back(dimShape);
2724 newShape.push_back(affineExprToSize[dimExpr]);
2725 newOperandNeeded =
true;
2727 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2728 sourceType.getEncoding());
2729 if (newOperandNeeded) {
2730 changeNeeded =
true;
2733 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2735 newOperands[index] = newOperand;
2737 if (linalgOp.isDpsInit(opOperand))
2738 resultTypes.push_back(resultType);
2744struct InferStaticShapeOfOperands :
public OpInterfaceRewritePattern<LinalgOp> {
2745 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2747 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2748 PatternRewriter &rewriter)
const override {
2749 if (!linalgOp.hasPureTensorSemantics())
2753 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2754 return !map.isProjectedPermutation();
2759 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2760 Location loc = linalgOp.getLoc();
2764 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2766 SmallVector<Value> newOperands;
2767 SmallVector<Type> resultTypes;
2771 bool changeNeeded =
false;
2772 newOperands.reserve(linalgOp->getNumOperands());
2773 resultTypes.reserve(linalgOp.getNumDpsInits());
2776 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2777 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2778 affineExprToSize, linalgOp, newOperands,
2779 resultTypes, changeNeeded);
2788 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2789 SmallVector<Value> replacements;
2791 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2792 Value newResult = std::get<1>(it);
2793 Value oldResult = std::get<0>(it);
2794 Type newType = newResult.
getType();
2795 Type oldType = oldResult.
getType();
2796 replacements.push_back(
2797 (newType != oldType)
2798 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2801 rewriter.
replaceOp(linalgOp, replacements);
2815LogicalResult SoftmaxOp::verify() {
2816 ShapedType inputType = getInputOperandType();
2817 ShapedType outputType = getOutputOperandType();
2819 ArrayRef<int64_t> inputShape = inputType.getShape();
2820 ArrayRef<int64_t> outputShape = outputType.getShape();
2824 int64_t inputRank = getInputOperandRank();
2825 int64_t dimension = getDimension();
2826 if ((dimension < 0) || (dimension >= inputRank))
2827 return emitOpError(
"incorrect dimension specified");
2832SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2833 int64_t operandRank = getInputOperandRank();
2834 SmallVector<Range> loopBounds(operandRank);
2835 Location loc = getLoc();
2838 Value source = getInput();
2839 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2840 loopBounds[dim].offset = zero;
2841 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2842 loopBounds[dim].stride = one;
2847SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2848 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2849 utils::IteratorType::parallel);
2850 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2851 return iteratorTypes;
2854FailureOr<TilingResult>
2855SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2856 ArrayRef<OpFoldResult> offsets,
2857 ArrayRef<OpFoldResult> sizes) {
2858 int64_t rank = getInputOperandRank();
2860 SmallVector<OpFoldResult> strides(rank, oneAttr);
2861 SmallVector<Value> tiledOperands;
2862 Operation *inputSlice =
2863 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2865 return emitOpError(
"failed to compute input slice");
2867 tiledOperands.emplace_back(inputSlice->
getResult(0));
2868 Operation *outputSlice =
2869 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2871 return emitOpError(
"failed to compute output slice");
2873 tiledOperands.emplace_back(outputSlice->
getResult(0));
2875 SmallVector<Type, 4> resultTypes;
2876 if (hasPureTensorSemantics())
2877 resultTypes.push_back(tiledOperands[1].
getType());
2878 Operation *tiledOp =
2879 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2881 return TilingResult{
2884 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2887LogicalResult SoftmaxOp::getResultTilePosition(
2888 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2889 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2890 SmallVector<OpFoldResult> &resultSizes) {
2891 if (resultNumber == 0) {
2892 resultOffsets.assign(offsets.begin(), offsets.end());
2893 resultSizes.assign(sizes.begin(), sizes.end());
2900LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2905SoftmaxOp::reifyResultShapes(OpBuilder &
b,
2907 SmallVector<OpFoldResult> shapes;
2908 Location loc = getOperation()->getLoc();
2909 IRRewriter rewriter(
b);
2910 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2911 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2912 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2913 if (!outputShapedType.isDynamicDim(dim)) {
2915 shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
2922 reifiedReturnShapes.emplace_back(std::move(shapes));
2926void SoftmaxOp::getEffects(
2927 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2929 for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2930 if (!llvm::isa<MemRefType>(operand.
getType()))
2933 &getOperation()->getOpOperand(index), 0,
2938 for (OpOperand &operand : getDpsInitsMutable()) {
2939 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2970static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2972 int64_t dim,
bool allParallel =
false) {
2974 utils::IteratorType::parallel);
2976 iteratorTypes[dim] = utils::IteratorType::reduction;
2980 for (
int i = 0; i < inputRank; i++) {
2987 return std::make_tuple(iteratorTypes, indexingMaps);
2992template <
typename T>
2995 auto inputType = cast<ShapedType>(input.
getType());
2997 int64_t inputRank = inputShape.size();
2998 auto [iteratorTypes, indexingMaps] =
3000 assert(indexingMaps.size() == 2 &&
3001 "We should have two maps: 1 for the input, 1 for the output");
3002 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3004 auto genericOp = linalg::GenericOp::create(
3005 builder, loc, output.
getType(), input, output, indexingMaps,
3007 Value result = T::create(b, loc, args[0], args[1]);
3008 linalg::YieldOp::create(b, loc, result);
3010 return genericOp.getResult(0);
3018 auto inputType = cast<ShapedType>(input.
getType());
3020 int64_t inputRank = inputShape.size();
3022 builder, inputRank, dim,
true);
3023 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
3024 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3026 indexingMaps.push_back(indexingMaps[0]);
3027 auto genericOp = linalg::GenericOp::create(
3029 indexingMaps, iteratorTypes,
3031 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3032 Value result = math::ExpOp::create(b, loc, diff);
3033 linalg::YieldOp::create(b, loc, result);
3035 return genericOp.getResult(0);
3045 auto inputType = cast<ShapedType>(numerator.
getType());
3047 int64_t inputRank = inputShape.size();
3049 builder, inputRank, dim,
true);
3050 assert(indexingMaps.size() == 2 &&
3051 "We should have one map for each input (2)");
3052 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
3054 indexingMaps.push_back(indexingMaps[0]);
3055 auto genericOp = linalg::GenericOp::create(
3057 output, indexingMaps, iteratorTypes,
3059 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3060 linalg::YieldOp::create(b, loc, result);
3062 return genericOp.getResult(0);
3084FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
3085 OpBuilder::InsertionGuard guard(
b);
3086 b.setInsertionPoint(*
this);
3087 Location loc = getLoc();
3088 Value input = getInput();
3089 ShapedType inputType = getInputOperandType();
3090 Type elementType = inputType.getElementType();
3091 int64_t reductionDim = getDimension();
3093 Value output = getOutput();
3094 dims.erase(dims.begin() + reductionDim);
3096 Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
3098 elementType,
b, loc,
3100 Value neutralForMaxFInit =
3101 linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
3113 linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
3119 buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
3120 return SmallVector<Value>{
result};
3127LogicalResult WinogradFilterTransformOp::verify() {
3128 auto filterType = cast<ShapedType>(getFilter().
getType());
3129 ArrayRef<int64_t> filterShape = filterType.getShape();
3130 int64_t filterH = filterShape[getFilterHDim()];
3131 int64_t filterW = filterShape[getFilterWDim()];
3132 WinogradConv2DFmr fmr = getFmr();
3136 if (filterH != r && filterH != 1)
3137 return emitOpError(
"expect filter height either equals to r or 1");
3138 if (filterW != r && filterW != 1)
3139 return emitOpError(
"expect filter width either equals to r or 1");
3140 if (filterH == 1 && filterW == 1)
3141 return emitOpError(
"expect either filter height or width equals to r");
3143 SmallVector<int64_t> expectedOutputShape;
3144 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3145 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3146 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3147 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3149 auto outputType = cast<ShapedType>(getOutput().
getType());
3150 ArrayRef<int64_t> outputShape = outputType.getShape();
3152 return emitOpError(
"the output shape is not expected");
3158WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3159 Location loc = getLoc();
3162 Value filter = getFilter();
3163 int64_t filterRank = getFilterOperandRank();
3164 SmallVector<Range> loopBounds(filterRank);
3165 for (
unsigned dim = 0; dim < filterRank; ++dim) {
3166 loopBounds[dim].offset = zeroAttr;
3167 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3168 loopBounds[dim].stride = oneAttr;
3173SmallVector<utils::IteratorType>
3174WinogradFilterTransformOp::getLoopIteratorTypes() {
3175 int64_t filterRank = getFilterOperandRank();
3176 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3177 utils::IteratorType::parallel);
3178 return iteratorTypes;
3181LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3182 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3183 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3184 SmallVector<OpFoldResult> &resultSizes) {
3186 ShapedType filterType = getFilterOperandType();
3187 ArrayRef<int64_t> filterShape = filterType.getShape();
3188 int64_t filterH = filterShape[getFilterHDim()];
3189 int64_t filterW = filterShape[getFilterWDim()];
3190 WinogradConv2DFmr fmr = getFmr();
3193 int64_t alpha = m + r - 1;
3194 int64_t alphaH = filterH != 1 ? alpha : 1;
3195 int64_t alphaW = filterW != 1 ? alpha : 1;
3199 resultOffsets.append(
3200 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3202 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3213FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3214 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3215 ArrayRef<OpFoldResult> sizes) {
3218 ShapedType filterType = getFilterOperandType();
3219 ArrayRef<int64_t> filterShape = filterType.getShape();
3220 int64_t filterH = filterShape[getFilterHDim()];
3221 int64_t filterW = filterShape[getFilterWDim()];
3224 SmallVector<Value> tiledOperands;
3225 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3227 sliceOffsets.append(
3228 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3229 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3230 sizes[getFilterCDim()]});
3231 int64_t filterRank = getFilterOperandRank();
3232 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3233 Location loc = getLoc();
3234 auto filterSlice = tensor::ExtractSliceOp::create(
3235 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3236 tiledOperands.emplace_back(filterSlice);
3238 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3243 int64_t outputRank = getOutputOperandRank();
3244 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3245 auto outputSlice = tensor::ExtractSliceOp::create(
3246 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3247 tiledOperands.emplace_back(outputSlice);
3249 SmallVector<Type> resultTypes;
3250 resultTypes.push_back(tiledOperands[1].
getType());
3251 Operation *tiledOp =
3252 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3254 return TilingResult{
3257 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3264LogicalResult WinogradInputTransformOp::verify() {
3265 auto inputType = cast<ShapedType>(getInput().
getType());
3266 ArrayRef<int64_t> inputShape = inputType.getShape();
3267 int64_t inputH = inputShape[getInputHDim()];
3268 int64_t inputW = inputShape[getInputWDim()];
3269 WinogradConv2DFmr fmr = getFmr();
3272 int64_t tileSize = m + r - 1;
3274 auto outputType = cast<ShapedType>(getOutput().
getType());
3275 ArrayRef<int64_t> outputShape = outputType.getShape();
3276 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3277 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3279 SmallVector<int64_t> expectedOutputShape(6, inputH);
3280 if (ShapedType::isDynamic(inputH)) {
3281 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3282 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3284 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3285 expectedOutputShape[getOutputTileHDim()] =
3286 leftTransform ? (inputH - (r - 1)) / m : inputH;
3288 if (ShapedType::isDynamic(inputW)) {
3289 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3290 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3292 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3293 expectedOutputShape[getOutputTileWDim()] =
3294 rightTransform ? (inputW - (r - 1)) / m : inputW;
3296 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3297 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3300 return emitOpError(
"the output shape is not expected");
3306WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3307 Location loc = getLoc();
3310 Value output = getOutput();
3311 int64_t outputRank = getOutputOperandRank();
3312 SmallVector<Range> loopBounds(outputRank);
3313 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3314 loopBounds[dim].offset = zeroAttr;
3316 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3317 loopBounds[dim].stride = oneAttr;
3322SmallVector<utils::IteratorType>
3323WinogradInputTransformOp::getLoopIteratorTypes() {
3324 int64_t outputRank = getOutputOperandRank();
3325 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3326 utils::IteratorType::parallel);
3327 return iteratorTypes;
3330LogicalResult WinogradInputTransformOp::getResultTilePosition(
3331 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3332 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3333 SmallVector<OpFoldResult> &resultSizes) {
3335 ShapedType outputType = getOutputOperandType();
3336 ArrayRef<int64_t> outputShape = outputType.getShape();
3337 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3338 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3340 WinogradConv2DFmr fmr = getFmr();
3343 int64_t alpha = m + r - 1;
3344 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3345 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3350 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3351 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3352 offsets[getOutputCDim()]});
3353 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3354 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3355 sizes[getOutputCDim()]});
3366FailureOr<TilingResult>
3367WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3368 ArrayRef<OpFoldResult> offsets,
3369 ArrayRef<OpFoldResult> sizes) {
3371 WinogradConv2DFmr fmr = getFmr();
3375 ShapedType outputType = getOutputOperandType();
3376 ArrayRef<int64_t> outputShape = outputType.getShape();
3377 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3378 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3380 Location loc = getLoc();
3382 auto identityAffineMap =
3384 auto offsetAffineMap =
3387 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3388 offsets[getOutputTileHDim()]);
3390 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3391 offsets[getOutputTileWDim()]);
3395 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3397 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3399 SmallVector<Value> tiledOperands;
3400 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3402 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3403 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3404 sliceOffsets.append(
3405 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3406 OpFoldResult sizeH =
3407 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3408 OpFoldResult sizeW =
3409 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3411 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3412 int64_t inputRank = getInputOperandRank();
3413 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3414 auto inputSlice = tensor::ExtractSliceOp::create(
3415 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3416 tiledOperands.emplace_back(inputSlice);
3418 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3423 int64_t outputRank = getOutputOperandRank();
3424 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3425 auto outputSlice = tensor::ExtractSliceOp::create(
3426 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3427 tiledOperands.emplace_back(outputSlice);
3429 SmallVector<Type> resultTypes;
3430 resultTypes.push_back(tiledOperands[1].
getType());
3431 Operation *tiledOp =
3432 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3434 return TilingResult{
3437 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3444LogicalResult WinogradOutputTransformOp::verify() {
3445 auto valueType = cast<ShapedType>(getValue().
getType());
3446 ArrayRef<int64_t> valueShape = valueType.getShape();
3447 int64_t valueH = valueShape[getValueAlphaHDim()];
3448 int64_t valueW = valueShape[getValueAlphaWDim()];
3449 int64_t valueTileH = valueShape[getValueTileHDim()];
3450 int64_t valueTileW = valueShape[getValueTileWDim()];
3451 WinogradConv2DFmr fmr = getFmr();
3454 bool leftTransform = valueH != 1;
3455 bool rightTransform = valueW != 1;
3457 int64_t outputRank = getOutputOperandRank();
3458 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3459 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3460 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3462 if (valueH != (leftTransform ? m + r - 1 : 1))
3463 return emitOpError(
"expect input height equals to input tile size");
3464 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3466 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3467 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3469 if (valueW != (rightTransform ? m + r - 1 : 1))
3470 return emitOpError(
"expect input width equals to input tile size");
3471 expectedOutputShape[getOutputWDim()] =
3472 (rightTransform ? m : 1) * valueTileW;
3474 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3475 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3477 auto outputType = cast<ShapedType>(getOutput().
getType());
3478 ArrayRef<int64_t> outputShape = outputType.getShape();
3480 return emitOpError(
"the output shape is not expected");
3486WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3487 Location loc = getLoc();
3490 Value value = getValue();
3491 int64_t valueRank = getValueOperandRank();
3492 SmallVector<Range> loopBounds(valueRank);
3493 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3494 loopBounds[dim].offset = zeroAttr;
3496 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3497 loopBounds[dim].stride = oneAttr;
3502SmallVector<utils::IteratorType>
3503WinogradOutputTransformOp::getLoopIteratorTypes() {
3504 int64_t valueRank = getValueOperandRank();
3505 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3506 utils::IteratorType::parallel);
3507 return iteratorTypes;
3510LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3511 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3512 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3513 SmallVector<OpFoldResult> &resultSizes) {
3514 WinogradConv2DFmr fmr = getFmr();
3518 Location loc = getLoc();
3520 auto identityAffineMap =
3525 ShapedType valueType = getValueOperandType();
3526 ArrayRef<int64_t> valueShape = valueType.getShape();
3527 int64_t valueH = valueShape[0];
3528 int64_t valueW = valueShape[1];
3530 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3531 offsets[getValueTileHDim()]);
3533 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3534 offsets[getValueTileWDim()]);
3536 builder, loc, affineMap, sizes[getValueTileHDim()]);
3538 builder, loc, affineMap, sizes[getValueTileWDim()]);
3541 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3542 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3543 OpFoldResult sizeH =
3544 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3545 OpFoldResult sizeW =
3546 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3548 resultOffsets.append(
3549 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3551 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3561FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3562 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3563 ArrayRef<OpFoldResult> sizes) {
3566 Location loc = getLoc();
3567 SmallVector<Value> tiledOperands;
3568 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3570 ShapedType valueType = getValueOperandType();
3571 ArrayRef<int64_t> valueShape = valueType.getShape();
3572 int64_t alphaH = valueShape[getValueAlphaHDim()];
3573 int64_t alphaW = valueShape[getValueAlphaWDim()];
3577 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3578 offsets[getValueTileWDim()], offsets[getValueNDim()],
3579 offsets[getValueFDim()]});
3580 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3581 sizes[getValueTileWDim()], sizes[getValueNDim()],
3582 sizes[getValueFDim()]});
3583 int64_t valueRank = getValueOperandRank();
3584 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3585 auto valueSlice = tensor::ExtractSliceOp::create(
3586 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3587 tiledOperands.emplace_back(valueSlice);
3589 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3594 int64_t outputRank = getOutputOperandRank();
3595 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3596 auto outputSlice = tensor::ExtractSliceOp::create(
3597 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3598 tiledOperands.emplace_back(outputSlice);
3600 SmallVector<Type> resultTypes;
3601 resultTypes.push_back(tiledOperands[1].
getType());
3602 Operation *tiledOp =
3603 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3605 return TilingResult{
3608 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3622 llvm::set_union(explicitSet, defaultSet);
3623 return explicitSet == defaultSet;
3643 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3645 auto opIndexingMap = opIndexingMaps[opIndex];
3646 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3649 return matmulOp->emitOpError()
3650 <<
"Unexpected dim expression in map result.";
3653 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3654 return matmulOp->emitOpError()
3655 <<
"Invalid broadcast requested, should be (d2).";
3664template <
typename OpTy>
3667 AffineMap defaultIndexingMap,
bool isLHS) {
3668 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3669 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3670 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3673 return batchVariantMatmulOp->emitOpError()
3674 <<
"Unexpected result dim expression (outside the set of default "
3679 return batchVariantMatmulOp->emitOpError()
3680 <<
"no. of result dim expressions exceeds 3.";
3682 auto hasValidBatchDim = [](
AffineMap map) {
3689 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3690 return batchVariantMatmulOp->emitOpError()
3691 <<
"Invalid broadcast requested.";
3692 }
else if (!hasValidBatchDim(opIndexingMap)) {
3693 return batchVariantMatmulOp->emitOpError()
3694 <<
"Invalid batch dimension expression.";
3702template <
typename OpTy>
3705 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3706 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3707 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3708 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3711 return batchVariantMatmulOp->emitOpError()
3712 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3715 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3717 return batchVariantMatmulOp->emitOpError()
3718 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3722 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3723 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3724 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3725 outputMap.getResult(1).isFunctionOfDim(1) &&
3726 outputMap.getResult(2).isFunctionOfDim(2)
3727 : outputMap.getResult(0).isFunctionOfDim(1) &&
3728 outputMap.getResult(1).isFunctionOfDim(2);
3731 if (!areValidOutputResultDim(opIndexingMap)) {
3732 return batchVariantMatmulOp->emitOpError()
3733 <<
"Invalid output map result dimension.";
3742template <
typename OpTy>
3747 batchVariantMatmulOp.getIndexingMapsArray();
3749 batchVariantMatmulOp.getDefaultIndexingMaps(
3750 batchVariantMatmulOp->getContext());
3752 if (opIndexingMaps.size() != 3)
3753 return batchVariantMatmulOp->emitOpError()
3754 <<
"Indexing_map attribute must have 3 affine maps.";
3756 auto opIndexingMap = opIndexingMaps[opIndex];
3757 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3765 defaultIndexingMap, opIndex == 0)))
3775 if (m == 2 && r == 3)
3776 return WinogradConv2DFmr::F_2_3;
3777 if (m == 4 && r == 3)
3778 return WinogradConv2DFmr::F_4_3;
3779 if (m == 2 && r == 5)
3780 return WinogradConv2DFmr::F_2_5;
3781 return std::nullopt;
3786 case WinogradConv2DFmr::F_2_3:
3788 case WinogradConv2DFmr::F_4_3:
3790 case WinogradConv2DFmr::F_2_5:
3793 llvm_unreachable(
"Unkown WinogradConv2DFmr");
3800static FailureOr<SmallVector<SmallVector<int64_t>>>
3803 for (
auto map : maps) {
3804 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3808 for (
auto result : attr.getAffineMap().getResults()) {
3809 auto dim = dyn_cast<AffineDimExpr>(
result);
3812 pos.push_back(dim.getPosition());
3814 positions.push_back(pos);
3827 return indexingMaps;
3830bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3831 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3834 if (maps.size() != 3)
3839 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3840 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3841 (*positions)[2] == SmallVector<int64_t>{0, 1};
3844SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3845 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3846 utils::IteratorType::parallel,
3847 utils::IteratorType::reduction};
3850unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3852std::string MatmulOp::getLibraryCallName() {
3856bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3860bool MatmulOp::hasUserDefinedMaps() {
3861 SmallVector<AffineMap, 3> defaultMaps =
3863 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3864 return defaultMaps != explicitMaps;
3869void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
3870 ArrayRef<NamedAttribute> attrs,
3873 emitError() <<
"MatmulOp regionBuilder expects 3 args, got "
3878 "MatmulOp regionBuilder expects 3 args");
3879 RegionBuilderHelper helper(
b, block);
3880 SmallVector<Value> yields;
3882 TypeFn castVal = TypeFn::cast_signed;
3883 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3884 return attr.
getName() ==
"cast";
3886 if (castIter != attrs.end()) {
3887 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3895 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2,
emitError);
3896 if (!value1 || !value2 || !value3)
3898 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
3902 yields.push_back(value4);
3903 helper.yieldOutputs(yields);
3913bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3914 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3915 AffineExpr expr = bcastMap.
getResult(0);
3925 ArrayAttr arrayAttr;
3929 if (llvm::any_of(arrayAttr,
3930 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3932 <<
"element of indexing_maps array is not an affine_map";
3939 if (failed(indexingMapsAttr))
3942 if (*indexingMapsAttr ==
nullptr) {
3943 auto indexingMapAttrs = llvm::map_to_vector(
3944 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3949 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
3951 MatmulOp::getRegionBuilder());
3954void MatmulOp::print(OpAsmPrinter &p) {
3955 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3956 MatmulOp::getDefaultIndexingMaps(
getContext()),
3957 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
3958 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3959 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3961 std::array<StringRef, 3> elidedAttrs = {
3962 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3968LogicalResult MatmulOp::verify() {
3970 if (!hasUserDefinedMaps())
3973 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3980LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3984void MatmulOp::getEffects(
3985 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3987 if (hasPureTensorSemantics())
3996SmallVector<AffineMap>
3997MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3998 AffineExpr d0, d1, d2;
4004 return {mapLHS, mapRHS, mapOut};
4008 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4011 if (maps.size() != 3)
4014 if (failed(positions))
4026 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4034 build(builder, state, inputs, outputs, attributes);
4035 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4036 assert(res &&
"builder didn't return the right type");
4046 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4055 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4056 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4057 assert(res &&
"builder didn't return the right type");
4067 result.addAttribute(
"cast", cast);
4069 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4078 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4079 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4080 assert(res &&
"builder didn't return the right type");
4085 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4087 op->
getAttr(
"indexing_maps"));
4091MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4098 return {mapLHS, mapRHS, mapOut};
4102 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4105 if (maps.size() != 3)
4108 if (failed(positions))
4120 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4128 build(builder, state, inputs, outputs, attributes);
4129 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4130 assert(res &&
"builder didn't return the right type");
4140 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4149 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4150 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4151 assert(res &&
"builder didn't return the right type");
4161 result.addAttribute(
"cast", cast);
4163 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4172 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4173 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4174 assert(res &&
"builder didn't return the right type");
4179 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4181 op->
getAttr(
"indexing_maps"));
4185BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4192 return {mapLHS, mapRHS, mapOut};
4196 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4199 if (maps.size() != 3)
4202 if (failed(positions))
4213 BatchMatmulOp::getRegionBuilder(),
4214 getDefaultIndexingMaps(builder));
4222 build(builder, state, inputs, outputs, attributes);
4223 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4224 assert(res &&
"builder didn't return the right type");
4233 BatchMatmulOp::getRegionBuilder(),
4234 getDefaultIndexingMaps(builder));
4243 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4244 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4245 assert(res &&
"builder didn't return the right type");
4253 result.addAttribute(
"cast", cast);
4255 BatchMatmulOp::getRegionBuilder(),
4256 getDefaultIndexingMaps(builder));
4265 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4266 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4267 assert(res &&
"builder didn't return the right type");
4272 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4274 op->
getAttr(
"indexing_maps"));
4278BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4285 return {mapLHS, mapRHS, mapOut};
4289 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4292 if (maps.size() != 3)
4295 if (failed(positions))
4306 BatchMatmulOp::getRegionBuilder(),
4307 getDefaultIndexingMaps(builder));
4315 build(builder, state, inputs, outputs, attributes);
4316 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4317 assert(res &&
"builder didn't return the right type");
4326 BatchMatmulOp::getRegionBuilder(),
4327 getDefaultIndexingMaps(builder));
4336 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4337 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4338 assert(res &&
"builder didn't return the right type");
4346 result.addAttribute(
"cast", cast);
4348 BatchMatmulOp::getRegionBuilder(),
4349 getDefaultIndexingMaps(builder));
4358 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4359 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4360 assert(res &&
"builder didn't return the right type");
4365 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4367 op->
getAttr(
"indexing_maps"));
4375 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4386 auto dimExpr = dyn_cast<AffineDimExpr>(
result);
4387 assert(dimExpr &&
"affine_map is a projected permutation");
4388 dimsInOutput[dimExpr.getPosition()] =
true;
4392 for (
auto dimOccursInOutput : dimsInOutput)
4393 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4394 : utils::IteratorType::reduction);
4396 return iteratorTypes;
4399unsigned ContractOp::getNumRegionArgs() {
return 3; }
4402void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
4403 ArrayRef<NamedAttribute> attrs,
4406 emitError() <<
"ContractOp regionBuilder expects 3 args, got "
4411 "ContractOp regionBuilder expects 3 args");
4412 RegionBuilderHelper helper(
b, block);
4414 TypeFn castSignedness = TypeFn::cast_signed;
4415 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4416 return attr.
getName() ==
"cast";
4418 if (castIter != attrs.end()) {
4419 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4425 Value lhsAtOutType =
4426 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
4427 Value rhsAtOutType =
4428 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
4429 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4431 if (!productAtOutType)
4437 helper.yieldOutputs({
result});
4440ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
4442 if (
failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
4444 "expected 'indexing_maps' attribute");
4445 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4451void ContractOp::print(OpAsmPrinter &p) {
4452 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4454 p, getOperation(), getInputs(), getOutputs(),
4455 {
"indexing_maps",
"operandSegmentSizes"});
4458LogicalResult ContractOp::verify() {
4459 int iterationSpaceDims = -1;
4464 SmallVector<size_t> inOccurrences;
4465 SmallVector<size_t> outOccurrences;
4468 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4469 bool isInput) -> LogicalResult {
4472 return emitError(
"provided affine_map is not a projected permutation");
4475 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
4477 return emitError(
"ranks of shaped operand and results of corresponding "
4478 "affine_map differ");
4480 return emitError(
"affine_map specifies shaped access while operand has "
4485 if (iterationSpaceDims == -1) {
4487 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4488 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4489 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
4490 return emitError(
"iteration spaces of provided affine_maps differ");
4494 for (AffineExpr affineExpr : affineMap.
getResults()) {
4495 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4497 llvm_unreachable(
"affine_map is a projected permutation");
4500 inOccurrences[affineDimExpr.getPosition()] += 1;
4502 outOccurrences[affineDimExpr.getPosition()] += 1;
4508 for (
auto &&[affineMap, operandType, isInput] :
4509 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4510 SmallVector<bool>{
true,
true,
false})) {
4511 if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4515 bool hasContractingDim =
false;
4516 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4517 size_t inOccCount = inOccurrences[dimIndex];
4518 size_t outOccCount = outOccurrences[dimIndex];
4521 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4523 if (inOccCount == 0 && outOccCount == 0)
4524 return emitError() <<
"iteration space dim at index " << dimIndex
4525 <<
" not used to access any operand";
4536 if (inOccCount == 1 && outOccCount != 1)
4538 <<
"iteration space dim at index " << dimIndex
4539 <<
" is neither a contracting dim nor of parallel iteration type";
4542 if (!hasContractingDim)
4543 return emitError(
"'indexing_maps' do not specify a contracting dimension");
4548LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4552void ContractOp::getEffects(
4553 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4555 if (hasPureTensorSemantics())
4567SmallVector<AffineMap>
4568BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4569 AffineExpr d0, d1, d2, d3;
4570 SmallVector<AffineMap> indexingMaps;
4572 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
4573 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
4574 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
4575 return indexingMaps;
4578bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4579 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4582 if (maps.size() != 3)
4587 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4588 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4589 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4592SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4593 return SmallVector<utils::IteratorType>{
4594 utils::IteratorType::parallel, utils::IteratorType::parallel,
4595 utils::IteratorType::parallel, utils::IteratorType::reduction};
4598unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
4600std::string BatchMatmulOp::getLibraryCallName() {
4606bool BatchMatmulOp::hasUserDefinedMaps() {
4607 SmallVector<AffineMap, 3> defaultMaps =
4609 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4610 return defaultMaps != explicitMaps;
4620bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
bool isLHS) {
4622 "Expected less than 3 result dim expr.");
4623 bool isValid =
false;
4624 enum Indices { batchPos, mPos, nPos, kPos };
4626 AffineExpr expr = bcastMap.
getResult(0);
4629 AffineExpr expr0 = bcastMap.
getResult(0);
4630 AffineExpr expr1 = bcastMap.
getResult(1);
4635 : ((expr0.isFunctionOfDim(batchPos) &&
4636 expr1.isFunctionOfDim(kPos)) ||
4637 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4642void BatchMatmulOp::regionBuilder(
4643 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4646 emitError() <<
"BatchMatmulOp regionBuilder expects 3 args, got "
4651 "BatchMatmulOp regionBuilder expects 3 args");
4652 RegionBuilderHelper helper(
b, block);
4653 SmallVector<Value> yields;
4655 TypeFn castVal = TypeFn::cast_signed;
4656 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4657 return attr.
getName() ==
"cast";
4659 if (castIter != attrs.end()) {
4660 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4665 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
4666 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
4668 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
4669 if (!castValA || !castValB || !mulVal)
4671 Value addVal = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
4675 yields.push_back(addVal);
4676 helper.yieldOutputs(yields);
4679ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
4680 SmallVector<Attribute, 3> indexingMapsAttr;
4692 if (!isa<AffineMapAttr>(mapAttr)) {
4694 "expected affine map attribute");
4696 indexingMapsAttr.push_back(mapAttr);
4706 if (indexingMapsAttr.empty()) {
4707 indexingMapsAttr = llvm::map_to_vector(
4708 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4709 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4711 result.addAttribute(
"indexing_maps",
4714 return ::parseNamedStructuredOp(parser,
result,
4715 BatchMatmulOp::getNumRegionArgs(),
4716 BatchMatmulOp::getRegionBuilder());
4719void BatchMatmulOp::print(OpAsmPrinter &p) {
4720 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4721 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4722 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4723 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4724 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4726 std::array<StringRef, 3> elidedAttrs = {
4727 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4733LogicalResult BatchMatmulOp::verify() {
4736 if (!hasUserDefinedMaps())
4739 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4746LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4747 SmallVectorImpl<OpFoldResult> &) {
4751void BatchMatmulOp::getEffects(
4752 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4754 if (hasPureTensorSemantics())
4768struct ArityGroupAndKind {
4770 ElementwiseArityGroup arityGroup;
4776 TernaryFn ternaryFn;
4780unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4781 return static_cast<unsigned>(arityGroup);
4786 constexpr int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4787 constexpr int lastBinary =
4788 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4789 constexpr int lastTernary =
4790 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4792 int val =
static_cast<int>(kind);
4793 ArityGroupAndKind
result;
4795 if (val < lastUnary) {
4796 result.arityGroup = ElementwiseArityGroup::Unary;
4797 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4800 if (val < lastBinary) {
4801 result.arityGroup = ElementwiseArityGroup::Binary;
4802 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4805 if (val >= lastTernary) {
4806 llvm_unreachable(
"unhandled ElementwiseFn");
4808 result.arityGroup = ElementwiseArityGroup::Ternary;
4809 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4814 auto rank = getResultRank();
4819ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
4825ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
4828 mlir::linalg::ElementwiseKind elemwiseKindVal;
4833 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4834 if (!elemwiseKindAttr)
4836 "expected ElementwiseKind attribute");
4837 elemwiseKindVal = elemwiseKindAttr.getValue();
4840 "expected operation 'kind' attribute");
4843 "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
4846 SmallVector<Attribute, 3> indexingMapsAttr;
4856 if (!isa<AffineMapAttr>(mapAttr))
4858 "expected affine map attribute");
4859 indexingMapsAttr.push_back(mapAttr);
4870 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4872 ElementwiseOp::getRegionBuilder())) {
4874 "unable to parse elemwise op");
4878 if (indexingMapsAttr.empty()) {
4881 auto resultType =
result.operands[
result.operands.size() - 1].getType();
4882 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4885 "return type needs to be shaped type");
4886 auto numDims = shapedType.getRank();
4887 indexingMapsAttr = llvm::map_to_vector(
4888 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4890 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4893 result.addAttribute(
"indexing_maps",
4898void ElementwiseOp::print(OpAsmPrinter &p) {
4901 SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes",
"kind",
4905 unsigned numDims = getResultRank();
4907 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4908 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4910 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4912 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4913 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4921void ElementwiseOp::regionBuilder(
4922 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4924 ElementwiseKind elemwiseKind;
4925 for (
auto attr : attrs) {
4926 if (attr.getName() ==
b.getStringAttr(
"kind")) {
4927 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4928 assert(kindAttr &&
"op kind attribute incorrectly set");
4929 elemwiseKind = kindAttr.getValue();
4935 auto arityGroup = groupAndKind.arityGroup;
4936 auto kind = groupAndKind.kind;
4938 getArityGroupAsUInt(arityGroup) + 1 ) {
4939 emitError() <<
"Elementwise regionBuilder expects "
4940 << (getArityGroupAsUInt(arityGroup) + 1) <<
" args, got "
4945 getArityGroupAsUInt(arityGroup) + 1
4946 &&
"Elementwise regionBuilder number of block args mismatch");
4948 RegionBuilderHelper helper(
b, block);
4949 SmallVector<Value> yields;
4952 if (arityGroup == ElementwiseArityGroup::Unary) {
4955 }
else if (arityGroup == ElementwiseArityGroup::Binary) {
4959 }
else if (arityGroup == ElementwiseArityGroup::Ternary) {
4964 assert(
false &&
"found unhandled category in elemwise");
4967 yields.push_back(
result);
4968 helper.yieldOutputs(yields);
4971LogicalResult ElementwiseOp::fold(FoldAdaptor,
4972 SmallVectorImpl<OpFoldResult> &) {
4976void ElementwiseOp::getEffects(
4977 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4979 if (hasPureTensorSemantics())
4992template <
typename OpTy,
typename>
4995 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4996 ? packOrUnPack.getDestType()
4997 : packOrUnPack.getSourceType();
4998 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
4999 ? packOrUnPack.getSourceType()
5000 : packOrUnPack.getDestType();
5002 packedType.getShape().take_front(unpackedType.getRank()));
5003 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5025 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5027 .take_back(mixedTiles.size()),
5029 int64_t dimSize = std::get<0>(it);
5030 if (dimSize == ShapedType::kDynamic) {
5031 newMixedTileSizes.push_back(std::get<1>(it));
5038 if (
Attribute attr = llvm::dyn_cast_if_present<Attribute>(
tile)) {
5040 newMixedTileSizes.push_back(
tile);
5043 "tile size and dim size don't match!");
5044 newMixedTileSizes.push_back(
5049 return newMixedTileSizes;
5052template <
typename OpTy>
5056 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5057 "applies to only pack or unpack operations");
5058 int64_t destRank = op.getDestRank();
5060 for (
auto dim : llvm::seq<int64_t>(0, destRank))
5061 reifiedReturnShapes[0][dim] =
5066template <
typename OpTy>
5068 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5069 "applies to only pack or unpack operations");
5073 assert(tiles.size() == dimsToTile.size() &&
5074 "tiles must match indices of dimension to block");
5076 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5077 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5078 return dimAndTileMapping;
5081template <
typename OpTy>
5083 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5084 "applies to only pack or unpack operations");
5087 unsigned dynamicValIndex = 0;
5088 for (
int64_t staticTile : op.getStaticInnerTiles()) {
5089 if (ShapedType::isStatic(staticTile))
5092 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5094 return mixedInnerTiles;
5097template <
typename OpTy>
5099 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5100 "applies to only pack or unpack operations");
5113 size_t dimsPosSize = dimsPos.size();
5114 if (dimsPosSize > rank)
5117 if (dimsPosSize != uniqued.size())
5119 return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
5120 return dimPos < 0 || dimPos >=
static_cast<int64_t>(rank);
5124template <
typename OpTy>
5126 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5127 "applies to only pack or unpack operations");
5128 Operation *op = packOrUnPack.getOperation();
5138 if (!packOrUnPack.getSourceType().hasRank() ||
5139 !packOrUnPack.getDestType().hasRank())
5140 return op->
emitError(
"expected both source and destination to have rank");
5143 if (!packOrUnPack.hasPureBufferSemantics() &&
5144 !packOrUnPack.hasPureTensorSemantics())
5145 return op->
emitError(
"mixing tensor and buffer semantics is not allowed");
5146 const unsigned numResults = packOrUnPack.getNumResults();
5147 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5148 return op->
emitError(
"expected 1 result, got ") << numResults;
5149 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5150 return op->
emitError(
"expected 0 results, got ") << numResults;
5154 if (hasZeros(mixedTiles))
5155 return op->
emitError(
"invalid zero tile factor");
5158 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5159 ? packOrUnPack.getSourceType()
5160 : packOrUnPack.getDestType();
5161 size_t unpackedRank = unpackedType.getRank();
5165 return op->
emitError(
"invalid inner_dims_pos vector");
5167 return op->
emitError(
"invalid outer_dims_perm vector");
5168 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5169 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
5173 if (mixedTiles.size() > unpackedRank) {
5174 return op->
emitError(
"tiling factors must be less than or equal to the "
5175 "input rank for pack or output rank for unpack");
5177 if (mixedTiles.size() != innerDimsPos.size()) {
5179 "tiling factors must equal the number of dimensions to tile");
5182 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5183 ? packOrUnPack.getDestType()
5184 : packOrUnPack.getSourceType();
5185 size_t packedRank = packedType.getRank();
5187 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5188 if (expectedPackedRank != packedRank) {
5190 "packed rank != (unpacked rank + num tiling factors), got ")
5191 << packedRank <<
" != " << expectedPackedRank;
5198 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5199 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5200 for (
auto it : llvm::enumerate(llvm::zip(
5201 packedType.getShape().take_back(mixedTiles.size()), mixedTiles))) {
5202 int64_t dimSize = std::get<0>(it.value());
5204 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it.value()))) {
5205 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5206 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5207 if (dimSize != staticTileSize)
5209 "mismatch in inner tile sizes specified and shaped of "
5210 "tiled dimension in the packed type at index ")
5211 << it.index() <<
": got " << dimSize <<
" != " << staticTileSize;
5212 }
else if (!ShapedType::isDynamic(dimSize)) {
5213 return op->
emitError(
"mismatch in inner tile sizes specified at index ")
5214 << it.index() <<
": got static shape " << dimSize
5215 <<
" but dynamic tile size";
5220 auto elementType = unpackedType.getElementType();
5221 Type expectedType, actualType;
5222 if (packOrUnPack.hasPureTensorSemantics()) {
5223 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5224 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5226 expectedType = MemRefType::get(expectedPackedShape, elementType);
5227 actualType = MemRefType::get(packedType.getShape(), elementType);
5230 << expectedType <<
" for the packed domain value, got "
5243struct PackOrUnPackTransposeResult {
5250template <
typename OpTy>
5251static PackOrUnPackTransposeResult
5255 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5256 "applies to only pack or unpack operations");
5257 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5258 "some permutation must be non-empty");
5259 PackOrUnPackTransposeResult metadata;
5260 metadata.innerDimsPos =
5262 metadata.innerTiles =
5264 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5265 ? packOrUnPackOp.getSourceRank()
5266 : packOrUnPackOp.getDestRank();
5267 metadata.outerDimsPerm =
5268 packOrUnPackOp.getOuterDimsPerm().empty()
5269 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5271 if (!innerPermutation.empty()) {
5272 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5274 "invalid inner permutation");
5278 if (!outerPermutation.empty()) {
5279 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5281 "invalid outer permutation");
5292 if (!getResults().empty())
5293 setNameFn(getResult(),
"pack");
5303 Type sourceType, destType, resultType;
5320 SmallVector<int64_t> outerDimsPermVec;
5323 if (parser.parseInteger(value))
5325 outerDimsPermVec.push_back(value);
5335 SmallVector<int64_t> innerDimsPosVec;
5338 if (parser.parseInteger(value))
5340 innerDimsPosVec.push_back(value);
5352 for (
auto val : staticTilesAttr.
asArrayRef())
5353 staticTiles.push_back(val);
5370 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5373 "pack/unpack requires '->' and destination type");
5377 resultType = destType;
5383 if (!paddingValue.empty() &&
5388 if (!dynamicTiles.empty() &&
5393 result.addAttribute(
"static_inner_tiles",
5395 result.addAttribute(
"inner_dims_pos", innerDimsPos);
5397 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
5399 SmallVector<int32_t> segmentSizes = {
5400 1, 1,
static_cast<int32_t
>(paddingValue.size()),
5401 static_cast<int32_t
>(dynamicTiles.size())};
5402 result.addAttribute(
"operandSegmentSizes",
5406 result.addTypes(resultType);
5411void PackOp::print(OpAsmPrinter &p) {
5412 p <<
" " << getSource();
5414 if (getPaddingValue()) {
5415 p <<
" padding_value(" << getPaddingValue() <<
" : "
5416 << getPaddingValue().getType() <<
")";
5419 if (!getOuterDimsPerm().empty()) {
5420 p <<
" outer_dims_perm = [";
5421 llvm::interleaveComma(getOuterDimsPerm(), p);
5425 p <<
" inner_dims_pos = [";
5426 llvm::interleaveComma(getInnerDimsPos(), p);
5429 p <<
" inner_tiles = ";
5432 p <<
" into " << getDest();
5435 {
"static_inner_tiles",
"inner_dims_pos",
5436 "outer_dims_perm",
"operandSegmentSizes"});
5438 p <<
" : " << getSource().getType();
5439 p <<
" -> " << getDest().getType();
5442void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5443 Value dest, ArrayRef<int64_t> innerDimsPos,
5444 ArrayRef<OpFoldResult> innerTiles,
5445 std::optional<Value> paddingValue,
5446 ArrayRef<int64_t> outerDimsPerm) {
5447 assert(innerDimsPos.size() == innerTiles.size() &&
5448 "number of tile sizes specified must match the specified number of "
5449 "original dimensions to be tiled");
5450 SmallVector<int64_t> staticTileSizes;
5451 SmallVector<Value> dynamicTileSizes;
5453 build(builder, state, dest.
getType(), source, dest,
5454 paddingValue ? *paddingValue :
nullptr,
5455 outerDimsPerm.empty() ?
nullptr
5462PackOp::reifyResultShapes(OpBuilder &builder,
5471SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5475SmallVector<int64_t> PackOp::getStaticTiles() {
5479ArrayRef<int64_t> PackOp::getAllOuterDims() {
5480 ShapedType inputType = getSourceType();
5481 int64_t inputRank = inputType.getRank();
5482 return getDestType().getShape().take_front(inputRank);
5485SmallVector<int64_t> PackOp::getTiledOuterDims() {
5486 auto innerDimsPos = getInnerDimsPos();
5487 SmallVector<int64_t> outerDims(getAllOuterDims());
5488 SmallVector<int64_t> res;
5491 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5493 if (!outerDimPermInv.empty())
5497 for (
auto index : innerDimsPos)
5498 res.push_back(outerDims[index]);
5503bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5504 ArrayRef<int64_t> innerDimsPos,
5505 ArrayRef<int64_t> outputShape,
5506 ArrayRef<int64_t> outerDimsPerm,
5507 ArrayRef<OpFoldResult> innerTiles) {
5508 SmallVector<int64_t> outputTileSizes(
5509 outputShape.take_front(inputShape.size()));
5510 if (!outerDimsPerm.empty()) {
5511 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5512 "expected output and outer_dims_perm to have same size");
5516 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5517 if (ShapedType::isDynamic(inputShape[pos]))
5520 if (!constantTile) {
5521 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5522 (inputShape[pos] % outputTileSizes[pos] != 0))
5525 assert(*constantTile != 0 &&
"static tile size can't be zero");
5526 if (inputShape[pos] % (*constantTile) != 0) {
5534bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
5535 ArrayRef<int64_t> innerDimsPos,
5536 ArrayRef<int64_t> outputShape,
5537 ArrayRef<int64_t> outerDimsPerm,
5538 ArrayRef<OpFoldResult> innerTiles) {
5539 SmallVector<int64_t> outputTileSizes(
5540 outputShape.take_front(inputShape.size()));
5541 if (!outerDimsPerm.empty()) {
5542 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5543 "expected output and outer_dims_perm to have same size");
5547 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5548 if (ShapedType::isDynamic(inputShape[pos]) ||
5549 ShapedType::isDynamic(outputTileSizes[pos]))
5554 assert(*constantTile != 0 &&
"static tile size can't be zero");
5555 if (inputShape[pos] % (*constantTile) != 0)
5561LogicalResult PackOp::verify() {
5568 auto paddingValue = getPaddingValue();
5572 << getSourceType().getElementType()
5573 <<
" but got: " << paddingValue.getType();
5576 if (!paddingValue &&
5577 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
5578 getDestType().
getShape(), getOuterDimsPerm(),
5581 "invalid tile factor or output size provided. Only full tiles are "
5582 "supported when padding_value is not set");
5589static SmallVector<int64_t>
5592 for (
auto o : ofrs) {
5594 if (llvm::dyn_cast_if_present<Value>(o))
5595 result.push_back(ShapedType::kDynamic);
5607 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5608 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5610 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5611 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5614 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5615 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5619 if (!outerDimsPerm.empty())
5623 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5627SmallVector<OpFoldResult> PackOp::getResultShape(
5628 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5629 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5630 ArrayRef<int64_t> outerDimsPerm) {
5631 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5635 AffineExpr ceilDivExpr = s0.
ceilDiv(s1);
5636 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5638 builder, loc, ceilDivExpr,
5639 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5641 if (!outerDimsPerm.empty())
5643 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5645 SmallVector<int64_t> resultTypeShape =
5648 innerDimsPos, outerDimsPerm);
5654 for (
unsigned i = 0; i < resultDims.size(); ++i) {
5655 if (ShapedType::isStatic(resultTypeShape[i]))
5664RankedTensorType PackOp::inferPackedTensorType(
5665 RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
5666 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5667 SmallVector<int64_t> resultShape = inferPackedShape(
5668 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5669 return RankedTensorType::get(resultShape, sourceType.getElementType());
5672MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5673 ArrayRef<int64_t> innerTileSizes,
5674 ArrayRef<int64_t> innerDimsPos,
5675 ArrayRef<int64_t> outerDimsPerm) {
5676 SmallVector<int64_t> resultShape = inferPackedShape(
5677 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5678 return MemRefType::get(resultShape, sourceType.getElementType());
5681Value PackOp::createDestinationTensor(OpBuilder &
b, Location loc, Value source,
5682 ArrayRef<OpFoldResult> innerTileSizes,
5683 ArrayRef<int64_t> innerDimsPos,
5684 ArrayRef<int64_t> outerDimsPerm) {
5685 AffineExpr dim0, dim1;
5687 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5692 SmallVector<OpFoldResult> mixedSizes;
5693 for (
auto [index, value] : llvm::enumerate(
5694 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
5695 if (ShapedType::isDynamic(value))
5696 mixedSizes.push_back(
5697 tensor::DimOp::create(
b, loc, source, index).getResult());
5699 mixedSizes.push_back(
b.getIndexAttr(value));
5701 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5702 int64_t dimPos = std::get<0>(it);
5703 OpFoldResult tileSize = std::get<1>(it);
5704 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5706 if (!outerDimsPerm.empty())
5709 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5710 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5711 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
5714PackOp PackOp::createTransposedClone(OpBuilder &
b, Location loc,
5715 ArrayRef<int64_t> innerPermutation,
5716 ArrayRef<int64_t> outerPermutation) {
5718 *
this, innerPermutation, outerPermutation);
5719 Value transposedDest =
5720 createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
5721 metadata.innerDimsPos, metadata.outerDimsPerm);
5722 return PackOp::create(
b, loc, getSource(), transposedDest,
5723 metadata.innerDimsPos, metadata.innerTiles,
5724 getPaddingValue(), metadata.outerDimsPerm);
5727template <
typename OpTy>
5732 if (op.hasPureTensorSemantics())
5735 for (
OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5736 if (!llvm::isa<MemRefType>(opOperand.
get().
getType()))
5739 if (&opOperand == &op.getSourceMutable()) {
5743 }
else if (&opOperand == &op.getDestMutable()) {
5754void PackOp::getEffects(
5760void UnPackOp::getEffects(
5767template <
typename OpTy>
5769 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5770 "applies to only pack or unpack operations");
5771 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5773 : op.getSourceType();
5775 for (
auto [dimDest,
tile] : llvm::zip(
5776 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5778 if (!constTileSize || ShapedType::isDynamic(dimDest))
5785 if (!hasPureTensorSemantics())
5787 if (getPaddingValue())
5802 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5804 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5816 auto packTiles = packOp.getMixedTiles();
5817 auto unPackTiles = unPackOp.getMixedTiles();
5818 if (packTiles.size() != unPackTiles.size())
5820 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
5829 auto srcType = op.getSourceType();
5830 if (llvm::any_of(op.getInnerDimsPos(),
5831 [&](
int64_t pos) { return srcType.isDynamicDim(pos); }))
5833 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5835 return !PackOp::requirePaddingValue(
5836 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5837 op.getOuterDimsPerm(), op.getMixedTiles());
5844 bool changeNeeded =
false;
5845 srcShape.assign(packOp.getSourceType().getShape().begin(),
5846 packOp.getSourceType().getShape().end());
5847 destShape.assign(packOp.getDestType().getShape().begin(),
5848 packOp.getDestType().getShape().end());
5849 llvm::SmallSetVector<int64_t, 4> innerDims;
5850 innerDims.insert_range(packOp.getInnerDimsPos());
5852 if (!packOp.getOuterDimsPerm().empty())
5854 int srcRank = packOp.getSourceRank();
5855 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
5856 if (innerDims.contains(i))
5860 if (!inverseOuterDimsPerm.empty())
5861 destPos = inverseOuterDimsPerm[srcPos];
5862 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5863 ShapedType::isDynamic(destShape[destPos])) {
5866 int64_t size = srcShape[srcPos];
5867 if (ShapedType::isDynamic(size))
5868 size = destShape[destPos];
5869 srcShape[srcPos] = size;
5870 destShape[destPos] = size;
5871 changeNeeded =
true;
5873 return changeNeeded;
5876LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
5878 if (!packOp.hasPureTensorSemantics())
5882 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5883 if (unPackOp.getSourceType() == packOp.getDestType() &&
5884 !packOp.getPaddingValue() &&
5887 rewriter.
replaceOp(packOp, unPackOp.getSource());
5895 packOp.getPaddingValueMutable().clear();
5901 SmallVector<int64_t> srcShape, destShape;
5903 Location loc = packOp.getLoc();
5904 Value source = packOp.getSource();
5905 if (srcShape != packOp.getSourceType().getShape()) {
5906 auto newSrcType = packOp.getSourceType().clone(srcShape);
5908 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5910 Value dest = packOp.getDest();
5911 ShapedType originalResultType = packOp.getDestType();
5912 bool needUpdateDestType = (destShape != originalResultType.getShape());
5913 if (needUpdateDestType) {
5914 auto newDestType = packOp.getDestType().clone(destShape);
5916 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5919 packOp.getSourceMutable().assign(source);
5920 packOp.getDestMutable().assign(dest);
5921 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
5924 if (needUpdateDestType) {
5926 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5927 packOp.getResult());
5936template <
typename PackOrUnpackOp>
5938 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5939 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5940 "Function meant for pack/unpack");
5945 int64_t numPackedDims = innerDimsPos.size();
5946 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5947 if (orderedDims != innerDimsPos) {
5953 int64_t packedRank = packedTensorType.getRank();
5963 return llvm::all_of(
5964 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5965 [&packedShape](
int64_t i) {
return packedShape[i] == 1; });
5968bool PackOp::isLikePad() {
5969 auto packedTensorType =
5970 llvm::cast<ShapedType>((*this)->getResultTypes().front());
5974::mlir::LogicalResult
5975PackOp::fold(FoldAdaptor adaptor,
5977 if (!hasPureTensorSemantics())
5979 std::optional<Attribute> paddingValue;
5980 if (
auto pad = adaptor.getPaddingValue())
5982 if (
OpFoldResult reshapedSource = reshapeConstantSource(
5983 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5984 cast<TensorType>(getDestType()), paddingValue)) {
5985 results.push_back(reshapedSource);
6011 if (!op.hasPureTensorSemantics())
6032 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6033 op.getInnerDimsPos(), newMixedTileSizes,
6034 op.getPaddingValue(), op.getOuterDimsPerm());
6035 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6038 Value oldResult = op.getResult();
6039 Value newResult = newOp.getResult();
6042 ? tensor::CastOp::create(rewriter, op->getLoc(),
6043 oldResult.
getType(), newResult)
6056void UnPackOp::getAsmResultNames(
6058 if (!getResults().empty())
6059 setNameFn(getResult(),
"unpack");
6068 Type sourceType, destType, resultType;
6080 if (parser.parseInteger(value))
6082 outerDimsPermVec.push_back(value);
6092 SmallVector<int64_t> innerDimsPosVec;
6095 if (parser.parseInteger(value))
6097 innerDimsPosVec.push_back(value);
6109 for (
auto val : staticTilesAttr.
asArrayRef())
6110 staticTiles.push_back(val);
6127 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6130 "pack/unpack requires '->' and destination type");
6134 resultType = destType;
6140 if (!dynamicTiles.empty() &&
6145 result.addAttribute(
"static_inner_tiles",
6147 result.addAttribute(
"inner_dims_pos", innerDimsPos);
6149 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
6151 SmallVector<int32_t> segmentSizes = {
6152 1, 1, 0,
static_cast<int32_t
>(dynamicTiles.size())};
6153 result.addAttribute(
"operandSegmentSizes",
6157 result.addTypes(resultType);
6162void UnPackOp::print(OpAsmPrinter &p) {
6163 p <<
" " << getSource();
6165 if (!getOuterDimsPerm().empty()) {
6166 p <<
" outer_dims_perm = [";
6167 llvm::interleaveComma(getOuterDimsPerm(), p);
6171 p <<
" inner_dims_pos = [";
6172 llvm::interleaveComma(getInnerDimsPos(), p);
6175 p <<
" inner_tiles = ";
6178 p <<
" into " << getDest();
6181 {
"static_inner_tiles",
"inner_dims_pos",
6182 "outer_dims_perm",
"operandSegmentSizes"});
6184 p <<
" : " << getSource().getType();
6185 p <<
" -> " << getDest().getType();
6189UnPackOp::reifyResultShapes(OpBuilder &builder,
6198SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
6202SmallVector<int64_t> UnPackOp::getStaticTiles() {
6206ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
6207 ShapedType destType = getDestType();
6208 int64_t destRank = destType.getRank();
6209 return getSourceType().getShape().take_front(destRank);
6212SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
6213 auto innerDimsPos = getInnerDimsPos();
6214 SmallVector<int64_t> outerDims(getAllOuterDims());
6215 SmallVector<int64_t> res;
6218 SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
6220 if (!outerDimPermInv.empty())
6224 for (
auto index : innerDimsPos)
6225 res.push_back(outerDims[index]);
6230LogicalResult UnPackOp::verify() {
6235 if (!hasPureTensorSemantics())
6244void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
6245 Value dest, ArrayRef<int64_t> innerDimsPos,
6246 ArrayRef<OpFoldResult> innerTiles,
6247 ArrayRef<int64_t> outerDimsPerm) {
6248 assert(innerDimsPos.size() == innerTiles.size() &&
6249 "number of tile sizes specified must match the specified number of "
6250 "original dimensions to be tiled");
6251 SmallVector<int64_t> staticTileSizes;
6252 SmallVector<Value> dynamicTileSizes;
6254 build(builder, state, dest.
getType(), source, dest,
6255 outerDimsPerm.empty() ?
nullptr
6261Value UnPackOp::createDestinationTensor(OpBuilder &
b, Location loc,
6263 ArrayRef<OpFoldResult> innerTileSizes,
6264 ArrayRef<int64_t> innerDimsPos,
6265 ArrayRef<int64_t> outerDimsPerm) {
6266 AffineExpr sym0, sym1;
6268 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
6272 SmallVector<OpFoldResult> mixedSizes;
6273 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
6275 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6276 if (srcType.isDynamicDim(i))
6277 mixedSizes.push_back(
6278 tensor::DimOp::create(
b, loc, source, i).getResult());
6280 mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
6282 if (!outerDimsPerm.empty()) {
6287 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6288 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6290 auto elemType = srcType.getElementType();
6291 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
6294UnPackOp UnPackOp::createTransposedClone(OpBuilder &
b, Location loc,
6295 Value transposedSource,
6296 ArrayRef<int64_t> innerPermutation,
6297 ArrayRef<int64_t> outerPermutation) {
6299 *
this, innerPermutation, outerPermutation);
6300 return UnPackOp::create(
b, loc, transposedSource, getDest(),
6301 metadata.innerDimsPos, metadata.innerTiles,
6302 metadata.outerDimsPerm);
6309 bool changeNeeded =
false;
6310 srcShape.assign(op.getSourceType().getShape().begin(),
6311 op.getSourceType().getShape().end());
6312 destShape.assign(op.getDestType().getShape().begin(),
6313 op.getDestType().getShape().end());
6314 llvm::SmallSetVector<int64_t, 4> innerDims;
6315 innerDims.insert_range(op.getInnerDimsPos());
6317 if (!op.getOuterDimsPerm().empty())
6319 int destRank = op.getDestRank();
6320 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
6321 if (innerDims.contains(i))
6325 if (!inverseOuterDimsPerm.empty())
6326 srcPos = inverseOuterDimsPerm[destPos];
6327 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6328 ShapedType::isDynamic(destShape[destPos])) {
6331 int64_t size = srcShape[srcPos];
6332 if (ShapedType::isDynamic(size))
6333 size = destShape[destPos];
6334 srcShape[srcPos] = size;
6335 destShape[destPos] = size;
6336 changeNeeded =
true;
6338 return changeNeeded;
6341LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6344 if (!unPackOp.hasPureTensorSemantics())
6348 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6349 if (packOp.getSourceType() != unPackOp.getDestType())
6351 if (packOp.getPaddingValue() ||
6355 rewriter.
replaceOp(unPackOp, packOp.getSource());
6359 if (
auto dstStyleOp =
6360 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6361 auto destValue = cast<OpResult>(unPackOp.getDest());
6362 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6364 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6368 if (unPackOp->hasOneUse()) {
6369 auto extractSliceUser =
6370 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6371 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6372 OpBuilder::InsertionGuard g(rewriter);
6374 auto newDest = tensor::ExtractSliceOp::create(
6375 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6376 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6377 extractSliceUser.getMixedStrides());
6379 unPackOp.setDpsInitOperand(0, newDest);
6380 unPackOp.getResult().setType(newDest.
getType());
6382 rewriter.
replaceOp(extractSliceUser, unPackOp);
6388 SmallVector<int64_t> srcShape, destShape;
6390 Location loc = unPackOp.getLoc();
6391 Value source = unPackOp.getSource();
6392 if (srcShape != unPackOp.getSourceType().getShape()) {
6393 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6394 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6395 unPackOp.getSource());
6397 Value dest = unPackOp.getDest();
6398 if (destShape != unPackOp.getDestType().getShape()) {
6399 auto newDestType = unPackOp.getDestType().clone(destShape);
6400 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6401 unPackOp.getDest());
6403 UnPackOp newOp = UnPackOp::create(
6404 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6405 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6407 unPackOp, unPackOp.getResult().
getType(), newOp.getResult());
6414bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6416 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6421 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6422 SmallVector<int64_t> outerShapeWithoutTranspose =
6424 SmallVector<bool> areOuterDimsTiled(outerShapeWithoutTranspose.size(),
false);
6425 for (
auto [pos, tileSize] :
6426 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6427 areOuterDimsTiled[pos] =
true;
6428 if (unpackedTypeAfterFold.isDynamicDim(pos))
6430 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6432 if (ShapedType::isDynamic(tileSize))
6434 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6435 unpackedTypeAfterFold.getDimSize(pos);
6436 if (paddingSize >= tileSize)
6440 for (int64_t pos = 0, e = outerShapeWithoutTranspose.size(); pos < e; ++pos) {
6441 if (areOuterDimsTiled[pos])
6443 int64_t dim = outerShapeWithoutTranspose[pos];
6444 if (ShapedType::isDynamic(dim))
6446 if (dim != unpackedTypeAfterFold.getDimSize(pos))
6452bool UnPackOp::isLikeUnPad() {
6453 ShapedType packedTensorType = getSourceType();
6457::mlir::LogicalResult
6458UnPackOp::fold(FoldAdaptor adaptor,
6459 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6461 if (!hasPureTensorSemantics())
6464 if (OpFoldResult reshapedSource = reshapeConstantSource(
6465 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6466 cast<TensorType>(getResult().
getType()))) {
6467 results.push_back(reshapedSource);
6493 if (!op.hasPureTensorSemantics())
6502 Value sourceTensor = newOperands[0];
6506 rewriter, sourceTensor.
getType(), op.getMixedTiles());
6512 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6513 newOperands[1], op.getInnerDimsPos(),
6514 newMixedTileSizes, op.getOuterDimsPerm());
6515 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6518 Value oldResult = op.getResult();
6519 Value newResult = newOp.getResult();
6522 ? tensor::CastOp::create(rewriter, op->getLoc(),
6523 oldResult.
getType(), newResult)
6537 utils::IteratorType::reduction, utils::IteratorType::parallel,
6538 utils::IteratorType::parallel, utils::IteratorType::reduction};
6541SmallVector<AffineMap>
6542BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6543 AffineExpr d0, d1, d2, d3;
6544 SmallVector<AffineMap> indexingMaps;
6546 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
6547 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
6549 return indexingMaps;
6552bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6553 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6556 if (maps.size() != 3)
6561 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6562 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6563 (*positions)[2] == SmallVector<int64_t>{1, 2};
6565unsigned BatchReduceMatmulOp::getNumRegionArgs() {
return 3; }
6567std::string BatchReduceMatmulOp::getLibraryCallName() {
6573bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6574 SmallVector<AffineMap, 3> defaultMaps =
6576 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6577 return defaultMaps != explicitMaps;
6587bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6590 "Expected less than 3 result dim expr.");
6591 bool isValid =
false;
6592 enum Indices { batchPos, mPos, nPos, kPos };
6594 AffineExpr expr = bcastMap.
getResult(0);
6597 AffineExpr expr0 = bcastMap.
getResult(0);
6598 AffineExpr expr1 = bcastMap.
getResult(1);
6603 : ((expr0.isFunctionOfDim(batchPos) &&
6604 expr1.isFunctionOfDim(kPos)) ||
6605 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6610void BatchReduceMatmulOp::regionBuilder(
6611 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
6614 emitError() <<
"BatchReduceMatmulOp regionBuilder expects 3 args, got "
6619 "BatchReduceMatmulOp regionBuilder expects 3 args");
6620 RegionBuilderHelper helper(
b, block);
6621 SmallVector<Value> yields;
6625 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
6627 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
6629 helper.buildBinaryFn(BinaryFn::mul, castValA, castValB,
emitError);
6630 if (!castValA || !castValB || !mulVal)
6633 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
6636 yields.push_back(addVal);
6637 helper.yieldOutputs(yields);
6640ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6641 OperationState &
result) {
6642 SmallVector<Attribute, 3> indexingMapsAttr;
6653 if (!isa<AffineMapAttr>(mapAttr)) {
6655 "expected affine map attribute");
6657 indexingMapsAttr.push_back(mapAttr);
6667 if (indexingMapsAttr.empty()) {
6668 indexingMapsAttr = llvm::map_to_vector(
6669 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
6670 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6672 result.addAttribute(
"indexing_maps",
6674 return ::parseNamedStructuredOp(parser,
result,
6675 BatchReduceMatmulOp::getNumRegionArgs(),
6676 BatchReduceMatmulOp::getRegionBuilder());
6679void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6680 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6681 BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
6682 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
6684 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6685 p <<
" indexing_maps = [";
6686 llvm::interleaveComma(getIndexingMaps(), p,
6691 SmallVector<StringRef, 3> elidedAttrs = {
6692 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
6698LogicalResult BatchReduceMatmulOp::verify() {
6701 if (!hasUserDefinedMaps())
6704 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
6710LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6711 SmallVectorImpl<OpFoldResult> &) {
6714void BatchReduceMatmulOp::getEffects(
6715 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6717 if (hasPureTensorSemantics())
6733void LinalgDialect::getCanonicalizationPatterns(
6742 return arith::ConstantOp::materialize(builder, value, type, loc);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > defaultIndexingMaps)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer, i.e. a cast whose source type is more static than the consumer's current operand type.
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. exprs.size()).
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. exprs.size()).
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult for a single OpFoldResult.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them, in their order of occurrence in forOps, under each of the targets.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static than the consuming op.
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static than the consuming op.
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override