40#include "llvm/ADT/DenseMap.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/SetOperations.h"
43#include "llvm/ADT/SmallVector.h"
44#include "llvm/ADT/StringSet.h"
45#include "llvm/ADT/TypeSwitch.h"
46#include "llvm/Support/FormatVariadic.h"
47#include "llvm/Support/InterleavedRange.h"
48#include "llvm/Support/LogicalResult.h"
49#include "llvm/Support/MathExtras.h"
50#include "llvm/Support/raw_ostream.h"
60 auto type = cast<ShapedType>(v.
getType());
61 if (!type.isDynamicDim(dim))
66 .Case<RankedTensorType>([&](RankedTensorType t) ->
Value {
67 return tensor::DimOp::create(builder, loc, v, dim);
69 .Case<MemRefType>([&](MemRefType t) ->
Value {
70 return memref::DimOp::create(builder, loc, v, dim);
81 .Case<RankedTensorType>([&](RankedTensorType t) ->
Operation * {
82 return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
85 .Case<MemRefType>([&](MemRefType type) ->
Operation * {
86 return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
98 if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
99 return b.createOrFold<memref::DimOp>(loc, source, dim);
100 if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
101 return b.createOrFold<tensor::DimOp>(loc, source, dim);
102 llvm_unreachable(
"Expected MemRefType or TensorType");
107 auto shapedType = llvm::cast<ShapedType>(source.
getType());
108 if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
110 return b.getIndexAttr(shapedType.getDimSize(dim));
133 for (
auto containers : {inputTypes, outputTypes}) {
134 for (
auto t : containers) {
146 opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
162 std::optional<TypeRange> resultTensorTypes,
169 if (!resultTensorTypes)
170 copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
171 llvm::IsaPred<RankedTensorType>);
179 "operandSegmentSizes",
180 b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
181 static_cast<int32_t>(outputs.size())}));
191 std::optional<TypeRange> resultTensorTypes,
198 indexingMapsAttrVal =
200 return AffineMapAttr::get(map);
202 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
204 attributes, regionBuilder);
208 std::optional<TypeRange> resultTensorTypes,
215 indexingMapsAttrVal =
217 return AffineMapAttr::get(map);
219 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
221 attributes, regionBuilder);
225 std::optional<TypeRange> resultTensorTypes,
232 indexingMapsAttrVal =
234 return AffineMapAttr::get(map);
236 state.
addAttribute(
"indexing_maps",
b.getArrayAttr(indexingMapsAttrVal));
238 attributes, regionBuilder);
247 bool addOperandSegmentSizes =
true) {
248 SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
277 if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
279 parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
283 if (addOperandSegmentSizes) {
290 if (
result.propertiesAttr) {
292 attrs.
append(
"operandSegmentSizes",
294 {static_cast<int32_t>(inputsOperands.size()),
295 static_cast<int32_t>(outputsOperands.size())}));
298 result.addAttribute(
"operandSegmentSizes",
300 {static_cast<int32_t>(inputsOperands.size()),
301 static_cast<int32_t>(outputsOperands.size())}));
304 if (!
result.propertiesAttr) {
305 std::optional<RegisteredOperationName> info =
306 result.name.getRegisteredInfo();
308 if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
309 return parser.emitError(attrsLoc)
310 <<
"'" << result.name.getStringRef() <<
"' op ";
321 p <<
" ins(" << inputs <<
" : " << inputs.
getTypes() <<
")";
322 if (!outputs.empty())
323 p <<
" outs(" << outputs <<
" : " << outputs.
getTypes() <<
")";
334 if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
337 llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated "
338 "region expects {0} args, got {1}",
339 numRegionArgs, inputTypes.size() + outputTypes.size()));
345 opBuilder, region, inputTypes, outputTypes, attrs,
364 unsigned numRegionArgs,
381 result.addTypes(outputTensorsTypes);
383 std::unique_ptr<Region> region = std::make_unique<Region>();
385 outputTypes,
result.attributes.getAttrs(),
388 result.addRegion(std::move(region));
395 if (resultTypes.empty())
440class RegionBuilderHelper {
442 RegionBuilderHelper(OpBuilder &builder,
Block &block)
443 : builder(builder), block(block) {}
446 Value buildUnaryFn(UnaryFn unaryFn, Value arg,
448 if (!isFloatingPoint(arg)) {
450 emitError() <<
"unsupported non numeric type";
453 llvm_unreachable(
"unsupported non numeric type");
455 OpBuilder::InsertionGuard g(builder);
456 builder.setInsertionPointToEnd(&block);
459 return math::ExpOp::create(builder, arg.
getLoc(), arg);
461 return math::LogOp::create(builder, arg.
getLoc(), arg);
463 return math::AbsFOp::create(builder, arg.
getLoc(), arg);
465 return math::CeilOp::create(builder, arg.
getLoc(), arg);
467 return math::FloorOp::create(builder, arg.
getLoc(), arg);
469 return arith::NegFOp::create(builder, arg.
getLoc(), arg);
470 case UnaryFn::reciprocal: {
471 Attribute oneAttr = builder.getOneAttr(arg.
getType());
472 auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
473 ::cast<TypedAttr>(oneAttr));
474 return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
477 return math::RoundOp::create(builder, arg.
getLoc(), arg);
479 return math::SqrtOp::create(builder, arg.
getLoc(), arg);
481 return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
482 case UnaryFn::square:
483 return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
485 return math::TanhOp::create(builder, arg.
getLoc(), arg);
487 return math::ErfOp::create(builder, arg.
getLoc(), arg);
490 emitError() <<
"unsupported unary function";
493 llvm_unreachable(
"unsupported unary function");
500 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
502 bool allComplex = isComplex(arg0) && isComplex(arg1);
503 bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
504 bool allInteger = isInteger(arg0) && isInteger(arg1);
507 if (!allComplex && !allFloatingPoint && !allInteger) {
510 <<
"Cannot build binary Linalg operation: expects allComplex, "
511 "allFloatingPoint, or allInteger, got "
515 llvm_unreachable(
"unsupported non numeric type");
517 OpBuilder::InsertionGuard g(builder);
518 builder.setInsertionPointToEnd(&block);
522 return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
523 if (allFloatingPoint)
524 return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
526 return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
527 return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
530 return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
531 if (allFloatingPoint)
532 return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
535 emitError() <<
"unsupported operation: sub with bools";
538 llvm_unreachable(
"unsupported operation: sub with bools");
540 return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
543 return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
544 if (allFloatingPoint)
545 return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
547 return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
548 return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
551 return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
552 if (allFloatingPoint)
553 return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
556 emitError() <<
"unsupported operation: div with bools";
559 llvm_unreachable(
"unsupported operation: div with bools");
561 return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
562 case BinaryFn::div_unsigned:
563 if (!allInteger || allBool) {
565 emitError() <<
"unsupported operation: unsigned div not on uint";
568 llvm_unreachable(
"unsupported operation: unsigned div not on uint");
570 return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
571 case BinaryFn::max_signed:
573 if (allFloatingPoint)
574 return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
575 return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
576 case BinaryFn::min_signed:
578 if (allFloatingPoint)
579 return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
580 return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
581 case BinaryFn::max_unsigned:
583 if (!allInteger || allBool) {
585 emitError() <<
"unsupported operation: unsigned max not on uint";
588 llvm_unreachable(
"unsupported operation: unsigned max not on uint");
590 return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
591 case BinaryFn::min_unsigned:
593 if (!allInteger || allBool) {
595 emitError() <<
"unsupported operation: unsigned min not on uint";
598 llvm_unreachable(
"unsupported operation: unsigned min not on uint");
600 return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
602 assert(allFloatingPoint);
603 return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
606 emitError() <<
"unsupported binary function";
609 llvm_unreachable(
"unsupported binary function");
613 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
617 bool tailFloatingPoint =
618 isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
619 bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
620 OpBuilder::InsertionGuard g(builder);
621 builder.setInsertionPointToEnd(&block);
623 case TernaryFn::select:
624 if (!headBool && !(tailFloatingPoint || tailInteger))
625 llvm_unreachable(
"unsupported non numeric type");
626 return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
629 emitError() <<
"unsupported ternary function";
632 llvm_unreachable(
"unsupported ternary function");
636 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
639 case TypeFn::cast_signed:
640 return cast(toType, operand,
false);
641 case TypeFn::cast_unsigned:
642 return cast(toType, operand,
true);
645 emitError() <<
"unsupported type conversion function";
648 llvm_unreachable(
"unsupported type conversion function");
652 OpBuilder::InsertionGuard g(builder);
653 builder.setInsertionPointToEnd(&block);
654 Location loc = builder.getUnknownLoc();
655 YieldOp::create(builder, loc, values);
658 Value constant(
const std::string &value) {
659 OpBuilder::InsertionGuard g(builder);
660 builder.setInsertionPointToEnd(&block);
661 Location loc = builder.getUnknownLoc();
662 Attribute valueAttr =
parseAttribute(value, builder.getContext());
663 return arith::ConstantOp::create(builder, loc,
664 ::cast<TypedAttr>(valueAttr));
667 Value index(int64_t dim) {
668 OpBuilder::InsertionGuard g(builder);
669 builder.setInsertionPointToEnd(&block);
670 return IndexOp::create(builder, builder.getUnknownLoc(), dim);
673 Type getIntegerType(
unsigned width) {
674 return IntegerType::get(builder.getContext(), width);
677 Type getFloat32Type() {
return Float32Type::get(builder.getContext()); }
678 Type getFloat64Type() {
return Float64Type::get(builder.getContext()); }
685 Value cast(Type toType, Value operand,
bool isUnsignedCast) {
686 OpBuilder::InsertionGuard g(builder);
687 builder.setInsertionPointToEnd(&block);
688 auto loc = operand.
getLoc();
689 if (isa<UnknownLoc>(loc)) {
699 bool isComplex(Value value) {
700 return llvm::isa<ComplexType>(value.
getType());
702 bool isFloatingPoint(Value value) {
703 return llvm::isa<FloatType>(value.
getType());
705 bool isInteger(Value value) {
706 return llvm::isa<IntegerType>(value.
getType());
722 using OpRewritePattern<CopyOp>::OpRewritePattern;
723 LogicalResult matchAndRewrite(CopyOp copyOp,
724 PatternRewriter &rewriter)
const override {
725 if (copyOp.getInputs() != copyOp.getOutputs())
727 if (copyOp.hasPureBufferSemantics())
730 rewriter.
replaceOp(copyOp, copyOp.getInputs());
740 results.
add<EraseSelfCopy>(context);
753template <
typename TensorReshapeOp>
755 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
756 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
757 PatternRewriter &rewriter)
const override {
758 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
762 Location loc = oldFill.getLoc();
763 TensorReshapeOp newInit;
764 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
766 newInit = TensorReshapeOp::create(
767 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
768 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
769 reshapeOp.getStaticOutputShape());
771 newInit = TensorReshapeOp::create(
772 rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
773 reshapeOp.getReassociation());
786 LogicalResult matchAndRewrite(tensor::PadOp padOp,
787 PatternRewriter &rewriter)
const override {
788 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
794 Value padValue = padOp.getConstantPaddingValue();
795 if (!padValue || fillOp.value() != padValue)
801 padOp,
"failed to reify tensor.pad op result shape");
804 tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
805 padOp.getResultType().getElementType());
807 FillOp::create(rewriter, fillOp.getLoc(),
ValueRange{padValue},
810 if (
replacement.getType() != padOp.getResultType()) {
811 replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
822struct FoldInsertPadIntoFill :
public OpRewritePattern<tensor::InsertSliceOp> {
825 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
826 PatternRewriter &rewriter)
const override {
827 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
831 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
836 Value firstDest = insertOp.getDest();
837 while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
838 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
843 bool disjoint =
false;
844 for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
847 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
848 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
849 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
853 int64_t prevStart = prevOp.getStaticOffset(i);
854 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
855 prevOp.getStaticStride(i);
856 int64_t nextStart = insertOp.getStaticOffset(i);
857 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
858 insertOp.getStaticStride(i);
859 if (prevEnd < nextStart || nextEnd < prevStart) {
867 firstDest = prevOp.getDest();
878 Value padValue = srcPadOp.getConstantPaddingValue();
879 if (!padValue || dstFillOp.value() != padValue)
882 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
883 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
885 Location loc = insertOp.getLoc();
888 AffineExpr sym0, sym1;
894 SmallVector<OpFoldResult, 4> newOffsets;
895 for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
896 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
897 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
900 RankedTensorType srcPadType = srcPadOp.getSourceType();
901 SmallVector<OpFoldResult, 4> newSizes;
902 for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
903 if (srcPadType.isDynamicDim(i)) {
905 tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
908 newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
913 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
914 newSizes, insertOp.getMixedStrides());
920struct FoldFillWithTensorExtract :
public OpRewritePattern<tensor::ExtractOp> {
922 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
924 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
925 PatternRewriter &rewriter)
const override {
928 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
933 Value extractedScalar = fillOp.getInputs()[0];
936 rewriter.
replaceOp(extractOp, extractedScalar);
944static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
945 linalg::PackOp packOp) {
946 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
950 if (
auto paddingValue = packOp.getPaddingValue())
954 Value packOpDest = packOp.getDest();
958 return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
965 FoldFillWithPack(MLIRContext *context)
966 : OpRewritePattern<linalg::PackOp>(context) {}
968 LogicalResult matchAndRewrite(linalg::PackOp packOp,
969 PatternRewriter &rewriter)
const override {
970 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
973 rewriter.
replaceOp(packOp, fillOp.value().result());
980 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
982 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
983 PatternRewriter &rewriter)
const override {
984 if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
987 copyOp.getOutputs());
990 if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
992 fillOp.getOutputs());
1001 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1003 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1004 PatternRewriter &rewriter)
const override {
1005 if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
1007 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
1008 transposeOp.getDpsInitOperand(0)->get());
1020 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
1021 PatternRewriter &rewriter)
const override {
1022 auto concatOperands = concatOp.getInputs();
1023 if (concatOperands.empty()) {
1027 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
1032 OpFoldResult firstFillVal =
1035 SmallVector<Value> allOuts;
1036 allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
1038 auto isDefinedByCompatibleFillOp = [&](Value v) ->
bool {
1039 auto fillOp = v.getDefiningOp<linalg::FillOp>();
1044 OpFoldResult fillVal =
1046 if (fillVal != firstFillVal)
1049 allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
1052 if (!llvm::all_of(concatOperands.drop_front(),
1053 isDefinedByCompatibleFillOp)) {
1055 concatOp,
"not all operands are defined by a compatible fill op");
1058 Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
1059 concatOp.getDim(), allOuts);
1061 concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
1070 results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
1071 FoldFillWithPack, FoldFillWithPad,
1072 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
1073 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
1074 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1087 for (
ValueRange container : {inputs, outputs}) {
1088 for (
Value v : container) {
1089 Type t = v.getType();
1090 blockArgTypes.push_back(
1092 blockArgLocs.push_back(v.getLoc());
1098 builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
1102void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
1104 for (Value v : getRegionInputArgs())
1106 for (Value v : getRegionOutputArgs())
1107 setNameFn(v,
"out");
1110void GenericOp::build(
1111 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1113 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1115 ArrayRef<NamedAttribute> attributes) {
1116 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1117 iteratorTypes, doc, libraryCall);
1118 result.addAttributes(attributes);
1121 inputs, outputs, bodyBuild);
1124void GenericOp::build(
1125 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1127 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1128 StringRef libraryCall,
1130 ArrayRef<NamedAttribute> attributes) {
1131 build(builder,
result, resultTensorTypes, inputs, outputs,
1135 [&](utils::IteratorType iter) -> mlir::Attribute {
1136 return IteratorTypeAttr::get(builder.getContext(), iter);
1139 libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
1140 bodyBuild, attributes);
1143void GenericOp::build(
1145 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1146 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1147 StringRef libraryCall,
1149 ArrayRef<NamedAttribute> attributes) {
1151 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1154void GenericOp::build(
1156 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1157 ArrayRef<utils::IteratorType> iteratorTypes,
1159 ArrayRef<NamedAttribute> attributes) {
1160 build(builder,
result, inputs, outputs, indexingMaps, iteratorTypes,
1162 "", bodyBuild, attributes);
1165void GenericOp::build(
1166 OpBuilder &builder, OperationState &
result,
TypeRange resultTensorTypes,
1168 ArrayRef<utils::IteratorType> iteratorTypes,
1170 ArrayRef<NamedAttribute> attributes) {
1171 build(builder,
result, resultTensorTypes, inputs, outputs, indexingMaps,
1174 "", bodyBuild, attributes);
1177void GenericOp::print(OpAsmPrinter &p) {
1181 auto genericAttrNames = linalgTraitAttrNames();
1183 llvm::StringSet<> genericAttrNamesSet;
1184 genericAttrNamesSet.insert_range(genericAttrNames);
1185 SmallVector<NamedAttribute, 8> genericAttrs;
1186 for (
auto attr : (*this)->getAttrs()) {
1187 if (attr.getName() == getIteratorTypesAttrName()) {
1188 auto iteratorTypes =
1189 llvm::cast<ArrayAttr>(attr.getValue())
1190 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1195 SmallVector<Attribute> iteratorTypeNames =
1196 llvm::to_vector(llvm::map_range(
1197 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1198 return StringAttr::get(
getContext(), stringifyIteratorType(t));
1201 genericAttrs.emplace_back(
1202 getIteratorTypesAttrName(),
1203 ArrayAttr::get(
getContext(), iteratorTypeNames));
1204 }
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1205 genericAttrs.push_back(attr);
1208 if (!genericAttrs.empty()) {
1209 auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
1210 p << genericDictAttr;
1216 genericAttrNames.push_back(
"operandSegmentSizes");
1217 genericAttrNamesSet.insert(genericAttrNames.back());
1219 bool hasExtraAttrs =
false;
1220 for (NamedAttribute n : (*this)->getAttrs()) {
1221 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1224 if (hasExtraAttrs) {
1231 if (!getRegion().empty()) {
1240ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
1241 DictionaryAttr dictAttr;
1249 result.attributes.assign(dictAttr.getValue().begin(),
1250 dictAttr.getValue().end());
1256 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1257 result.attributes.get(getIteratorTypesAttrName(
result.name)));
1258 if (!iteratorTypes) {
1259 return parser.
emitError(attributeLocation)
1260 <<
"expected " << getIteratorTypesAttrName(
result.name)
1261 <<
" array attribute";
1264 SmallVector<Attribute> iteratorTypeAttrs;
1266 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1267 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1268 if (!maybeIteratorType.has_value())
1270 <<
"unexpected iterator_type (" << s <<
")";
1272 iteratorTypeAttrs.push_back(
1273 IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
1275 result.attributes.set(getIteratorTypesAttrName(
result.name),
1279 SmallVector<Type, 1> inputTypes, outputTypes;
1289 std::unique_ptr<Region> region = std::make_unique<Region>();
1292 result.addRegion(std::move(region));
1298 SmallVector<Type, 1> outputTensorsTypes;
1301 result.addTypes(outputTensorsTypes);
1309 LinalgOp linalgOp) {
1310 for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1311 if (!llvm::isa<MemRefType>(operand.
getType()))
1313 effects.emplace_back(
1318 for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1319 if (!llvm::isa<MemRefType>(operand.get().
getType()))
1321 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1332void GenericOp::getEffects(
1333 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1342 if (!linalgOp.hasPureTensorSemantics())
1360template <
typename OpTy>
1361struct EraseIdentityLinalgOp :
public OpRewritePattern<OpTy> {
1362 using OpRewritePattern<OpTy>::OpRewritePattern;
1364 LogicalResult matchAndRewrite(OpTy linalgOp,
1365 PatternRewriter &rewriter)
const override {
1367 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1372 Block &body = linalgOp->getRegion(0).front();
1373 if (!llvm::hasSingleElement(body))
1375 auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
1380 if (linalgOp.hasPureBufferSemantics()) {
1381 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1382 linalgOp.getDpsInputOperand(0)->get() !=
1383 linalgOp.getDpsInitOperand(0)->get()) {
1385 linalgOp,
"expected single input and output to be the same value");
1388 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1389 if (!yieldArg || yieldArg.getOwner() != &body) {
1391 "cannot fold fill-like op");
1398 if (!linalgOp.hasPureTensorSemantics()) {
1400 linalgOp,
"mixed semantics is not supported yet");
1405 SmallVector<Value> returnedArgs;
1406 for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1407 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1408 if (!yieldArg || yieldArg.getOwner() != &body)
1410 unsigned argumentNumber = yieldArg.getArgNumber();
1411 Value returnedArg = linalgOp->getOperand(argumentNumber);
1412 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1415 Type returnType = returnedArg.
getType();
1416 if (returnType != resultType) {
1421 returnedArg = sparse_tensor::ConvertOp::create(
1422 rewriter, linalgOp.getLoc(), resultType, returnedArg);
1424 if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
1427 returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
1428 resultType, returnedArg);
1431 returnedArgs.push_back(returnedArg);
1434 if (returnedArgs.size() != linalgOp->getNumResults())
1436 rewriter.
replaceOp(linalgOp, returnedArgs);
1443void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1444 MLIRContext *context) {
1445 results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
1448LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1467 for (
Type outputType : outputTypes) {
1468 if (llvm::isa<RankedTensorType>(outputType))
1469 result.addTypes(outputType);
1473 if (parseAttrsFn && failed(parseAttrsFn(parser,
result.attributes)))
1482void MapOp::getAsmBlockArgumentNames(Region ®ion,
1484 for (Value v : getRegionInputArgs())
1486 for (Value v : getRegionOutputArgs())
1487 setNameFn(v,
"init");
1490void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
1491 if (!getResults().empty())
1492 setNameFn(getResults().front(),
"mapped");
1498 ArrayRef<NamedAttribute> attributes) {
1500 result.addAttributes(attributes);
1503 Type initType = init.
getType();
1504 if (llvm::isa<RankedTensorType>(initType))
1505 result.addTypes(initType);
1509 inputs, {init}, bodyBuild);
1516 bool initFirst =
false,
bool mapInit =
true) {
1520 b.setInsertionPointToStart(&block);
1521 for (
auto &operand : operands) {
1523 llvm::cast<ShapedType>(operand.
getType()).getElementType(),
1531 payloadOpOperands.push_back(block.
getArguments().back());
1532 for (
const auto &arg : block.
getArguments().drop_back())
1533 payloadOpOperands.push_back(arg);
1542 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1548ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
1549 std::optional<OperationName> payloadOpName;
1550 NamedAttrList payloadOpAttrs;
1553 if (
failed(operationName))
1557 payloadOpName = operationName.value();
1565 if (payloadOpName.has_value()) {
1566 if (!
result.operands.empty())
1568 payloadOpAttrs, ArrayRef(
result.operands),
false,
1573 SmallVector<OpAsmParser::Argument> regionArgs;
1578 Region *body =
result.addRegion();
1586 bool mapInit =
true) {
1588 if (initFirst && !mapInit)
1612 for (
const auto &[operand, bbArg] :
1614 if (bbArg != operand)
1618 for (
const auto &[operand, bbArg] :
1621 if (bbArg != operand)
1628 return yieldOp.getNumOperands() == 1 &&
1629 yieldOp.getOperand(0).getDefiningOp() &&
1630 yieldOp.getOperand(0).getDefiningOp() == &payload;
1635 std::string attrToElide;
1637 for (
const auto &attr : payloadOp->
getAttrs()) {
1639 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1640 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1641 attrToElide = attr.getName().str();
1642 elidedAttrs.push_back(attrToElide);
1650void MapOp::print(OpAsmPrinter &p) {
1651 Block *mapper = getBody();
1661 if (!useShortForm) {
1667 [&](
auto arg) { p.printRegionArgument(arg); });
1675LogicalResult MapOp::verify() {
1676 auto *bodyBlock = getBody();
1677 auto blockArgs = bodyBlock->getArguments();
1681 if (getInputs().size() + 1 != blockArgs.size())
1682 return emitOpError() <<
"expects number of operands to match the arity of "
1684 << getInputs().size() + 1 <<
" and "
1685 << blockArgs.size();
1688 for (
const auto &[bbArgType, inputArg] :
1689 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1690 auto inputElemType =
1691 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1692 if (bbArgType != inputElemType) {
1693 return emitOpError() <<
"expected element type of input " << inputElemType
1694 <<
" to match bbArg type " << bbArgType;
1699 auto outputShape = getInit().getType().getShape();
1700 for (Type inputArgType :
TypeRange{getInputs()}) {
1701 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1702 if (inputElemShape != outputShape) {
1703 return emitOpError() <<
"expected shape of input (" << inputElemShape
1704 <<
") to match shape of output (" << outputShape
1712SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1713 int64_t rank = getInit().getType().getRank();
1714 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1719 int64_t rank = getInit().getType().getRank();
1720 int64_t numIndexingMaps = getOperands().size();
1725void MapOp::getEffects(
1726 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1739void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
1741 for (Value v : getRegionInputArgs())
1743 for (Value v : getRegionOutputArgs())
1744 setNameFn(v,
"init");
1747void ReduceOp::getAsmResultNames(
1749 if (!getResults().empty())
1750 setNameFn(getResults().front(),
"reduced");
1753void ReduceOp::build(
1755 ValueRange inits, ArrayRef<int64_t> dimensions,
1757 ArrayRef<NamedAttribute> attributes) {
1759 result.addAttributes(attributes);
1762 for (Value init : inits) {
1763 Type initType = init.
getType();
1764 if (llvm::isa<RankedTensorType>(initType))
1765 result.addTypes(initType);
1770 inputs, inits, bodyBuild);
1773SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1775 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1776 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1777 utils::IteratorType::parallel);
1778 for (int64_t reductionDim : getDimensions())
1779 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1780 return iteratorTypes;
1785 llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
1786 SmallVector<AffineMap> affineMaps(
1789 AffineMap resultMap =
1792 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1793 affineMaps.push_back(resultMap);
1794 return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
1797void ReduceOp::getEffects(
1798 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1809 StringRef attributeName) {
1817ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
1818 std::optional<OperationName> payloadOpName;
1819 NamedAttrList payloadOpAttrs;
1822 if (
failed(operationName))
1826 payloadOpName = operationName.value();
1832 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1837 if (payloadOpName.has_value()) {
1839 ArrayRef(
result.operands),
true);
1841 SmallVector<OpAsmParser::Argument> regionArgs;
1847 Region *body =
result.addRegion();
1857 p <<
' ' << attributeName <<
" = [" << attributeValue <<
"] ";
1860void ReduceOp::print(OpAsmPrinter &p) {
1861 Block *mapper = getBody();
1870 if (!useShortForm) {
1876 [&](
auto arg) { p.printRegionArgument(arg); });
1884LogicalResult ReduceOp::verify() {
1885 ArrayRef<int64_t> dimensionsRef = getDimensions();
1887 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1890 return emitOpError() <<
"expects all inputs to have the same shapes. "
1891 "Shape at input-index "
1893 <<
" is not equal to the shape at input-index 0.";
1896 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1899 return emitOpError() <<
"expects all outputs to have the same shapes. "
1900 "Shape at output-index "
1902 <<
" is not equal to the shape at output-index 0.";
1905 auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
1906 auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
1909 for (int64_t dimension : dimensionsRef) {
1910 if (dimension < 0 || dimension >= inputType.getRank()) {
1912 <<
"dimensions for reduction should be in the range [0, "
1913 << inputType.getRank() - 1 <<
"].";
1915 dimensionsToReduce.insert(dimension);
1918 auto inputDims = inputType.getShape();
1919 auto initDims = initType.getShape();
1922 SmallVector<int64_t> reducedInputDims;
1923 for (
const auto &en : llvm::enumerate(inputDims)) {
1924 if (!dimensionsToReduce.count(en.index()))
1925 reducedInputDims.push_back(en.value());
1928 if (reducedInputDims.size() !=
static_cast<size_t>(initType.getRank())) {
1929 return emitOpError() <<
"number of dimensions after reduction "
1930 << reducedInputDims.size()
1931 <<
" doesn't match the init rank "
1932 << initType.getRank();
1935 if (reducedInputDims != initDims)
1936 return emitOpError() <<
"init dimensions [" << initDims
1937 <<
"] doesn't match input dimensions after reduction ["
1938 << reducedInputDims <<
"]";
1940 Block *block = getBody();
1943 <<
"mismatching number of operands and block arguments";
1946 for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
1947 Type inputElementType =
1948 llvm::cast<ShapedType>(input.getType()).getElementType();
1949 if (inputElementType != bbArg.getType())
1951 <<
"input element type " << inputElementType
1952 <<
" does not match corresponding block argument type "
1957 for (
auto [output, bbArg] : llvm::zip(
1958 getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
1959 auto outputElementType =
1960 llvm::cast<ShapedType>(output.getType()).getElementType();
1961 if (outputElementType != bbArg.getType())
1963 <<
"output element type " << outputElementType
1964 <<
" does not match corresponding block argument type "
1980 linalg::YieldOp::create(
b, loc, args[0]);
1984void TransposeOp::build(::mlir::OpBuilder &builder,
1985 ::mlir::OperationState &
result, Value input, Value init,
1987 ArrayRef<NamedAttribute> attributes) {
1988 result.addOperands(input);
1989 result.addOperands(init);
1990 result.addAttribute(getPermutationAttrName(
result.name), permutation);
1991 result.addAttributes(attributes);
1994 Type initType = init.
getType();
1995 if (llvm::isa<RankedTensorType>(initType))
1996 result.addTypes(initType);
2002void TransposeOp::build(::mlir::OpBuilder &builder,
2003 ::mlir::OperationState &
result, Value input, Value init,
2004 ArrayRef<int64_t> permutation,
2005 ArrayRef<NamedAttribute> attributes) {
2010ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
2012 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2024void TransposeOp::getAsmResultNames(
2026 if (!getResults().empty())
2027 setNameFn(getResults().front(),
"transposed");
2030void TransposeOp::print(OpAsmPrinter &p) {
2036LogicalResult TransposeOp::verify() {
2037 ArrayRef<int64_t> permutationRef = getPermutation();
2042 auto inputType = getInput().getType();
2043 auto initType = getInit().getType();
2045 int64_t rank = inputType.getRank();
2051 if (rank !=
static_cast<int64_t
>(permutationRef.size()))
2052 return emitOpError() <<
"size of permutation " << permutationRef.size()
2053 <<
" does not match the argument rank " << rank;
2055 auto inputDims = inputType.getShape();
2056 auto initDims = initType.getShape();
2058 for (int64_t i = 0; i < rank; ++i) {
2059 int64_t inputDim = inputDims[permutationRef[i]];
2060 int64_t initDim = initDims[i];
2062 if (inputDim != initDim) {
2063 return emitOpError() <<
"dim(result, " << i <<
") = " << initDim
2064 <<
" doesn't match dim(input, permutation[" << i
2065 <<
"]) = " << inputDim;
2072SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2073 int64_t rank = getInit().getType().getRank();
2074 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2077ArrayAttr TransposeOp::getIndexingMaps() {
2079 int64_t rank = getInit().getType().getRank();
2082 llvm::to_vector_of<unsigned>(getPermutation()),
getContext())),
2086void TransposeOp::getEffects(
2087 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2096LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2097 SmallVectorImpl<OpFoldResult> &
result) {
2099 if (!isa<TensorType>(getInput().
getType()))
2103 if (getPermutation().empty()) {
2104 result.push_back(getInput());
2109 result.push_back(getInput());
2122 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2123 if (!defTransposeOp)
2128 foldedPerms.reserve(perms.size());
2130 foldedPerms.push_back(defPerms[perm]);
2133 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2147 Value input = transposeOp.getInput();
2148 BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
2159 unsigned dimensionSize = dimensions.size();
2160 for (
unsigned i = 0; i < dimensionSize; ++i)
2161 resultDimensions.push_back(invertPerm[dimensions[i]]);
2164 Value broadcastInput = broadcastOp.getInput();
2165 Location loc = transposeOp.getLoc();
2168 auto broadcastInputTy =
2169 mlir::cast<RankedTensorType>(broadcastInput.
getType());
2170 unsigned inputRank = broadcastInputTy.getRank();
2171 for (
unsigned i = 0; i < inputRank; ++i) {
2172 if (broadcastInputTy.isDynamicDim(i)) {
2173 dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2176 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2177 broadcastInputTy.getDimSize(i)));
2182 Value transposeInit = tensor::EmptyOp::create(
2183 rewriter, transposeOp.getLoc(), transposeResultShapes,
2184 broadcastInputTy.getElementType());
2187 Value transposeResult =
2188 TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2189 transposeInit, resultPerms)
2192 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2197void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2198 MLIRContext *context) {
2199 results.
add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2206void BroadcastOp::build(::mlir::OpBuilder &builder,
2207 ::mlir::OperationState &
result, Value input, Value init,
2209 ArrayRef<NamedAttribute> attributes) {
2210 result.addOperands(input);
2211 result.addOperands(init);
2212 result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
2213 result.addAttributes(attributes);
2216 Type initType = init.
getType();
2217 if (llvm::isa<RankedTensorType>(initType))
2218 result.addTypes(initType);
2224void BroadcastOp::build(::mlir::OpBuilder &builder,
2225 ::mlir::OperationState &
result, Value input, Value init,
2226 ArrayRef<int64_t> dimensions,
2227 ArrayRef<NamedAttribute> attributes) {
2232ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
2234 parser,
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2246void BroadcastOp::getAsmResultNames(
2248 if (!getResults().empty())
2249 setNameFn(getResults().front(),
"broadcasted");
2252void BroadcastOp::print(OpAsmPrinter &p) {
2258LogicalResult BroadcastOp::verify() {
2259 ArrayRef<int64_t> dimensionsRef = getDimensions();
2261 auto inputType = getInput().getType();
2262 auto initType = getInit().getType();
2264 int64_t inputRank = inputType.getRank();
2265 int64_t initRank = initType.getRank();
2267 auto inputShape = inputType.getShape();
2268 auto initShape = initType.getShape();
2270 if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
2271 return emitOpError() <<
"input rank plus added dimensions does not "
2272 "match init rank. input rank: "
2274 <<
", dimensions size: " << dimensionsRef.size()
2275 <<
", init rank: " << initRank;
2277 for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2278 if (dim < 0 || dim >= initRank)
2280 <<
" is out of range. expected range: [0, "
2281 << initRank - 1 <<
"], got: " << dim;
2285 SmallVector<int64_t> dimMap;
2286 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2287 if (!llvm::is_contained(dimensionsRef, dim))
2288 dimMap.push_back(dim);
2291 for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2294 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2295 return emitOpError() <<
"input dim " << inputDimIdx
2296 <<
" should match init dim " << initDimIdx
2297 <<
". input: " << inputShape[inputDimIdx]
2298 <<
", init: " << initShape[initDimIdx];
2304SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2305 int64_t rank = getInit().getType().getRank();
2306 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2309ArrayAttr BroadcastOp::getIndexingMaps() {
2311 int64_t rank = getInit().getType().getRank();
2317void BroadcastOp::getEffects(
2318 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2333 auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2334 if (!defBroadcastOp)
2339 Value init = broadcastOp.getInit();
2343 for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
2344 if (!llvm::is_contained(dimensions, dim))
2345 dimMap.push_back(dim);
2347 for (
auto dim : defDimensions)
2348 foldedDims.push_back(dimMap[dim]);
2350 llvm::sort(foldedDims);
2352 broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2357void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2358 MLIRContext *context) {
2359 results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2366void linalg::YieldOp::print(OpAsmPrinter &p) {
2367 if (getNumOperands() > 0)
2368 p <<
' ' << getOperands();
2370 if (getNumOperands() > 0)
2371 p <<
" : " << getOperandTypes();
2374ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
2375 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2376 SmallVector<Type, 2> types;
2386static LogicalResult
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2387 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2388 return op.emitOpError(
"expected number of yield values (")
2389 << op.getNumOperands()
2390 <<
") to match the number of inits / outs operands of the enclosing "
2391 <<
"LinalgOp (" << linalgOp.getNumDpsInits() <<
")";
2393 for (
OpOperand &opOperand : op->getOpOperands()) {
2395 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2397 if (isa<MemRefType, RankedTensorType>(elementType))
2399 if (opOperand.get().getType() != elementType)
2400 return op.emitOpError(
"type of yield operand ")
2401 << (opOperand.getOperandNumber() + 1) <<
" ("
2402 << opOperand.get().getType() <<
") doesn't match "
2403 <<
"the element type of the enclosing linalg.generic op ("
2404 << elementType <<
")";
2409LogicalResult linalg::YieldOp::verify() {
2410 auto *parentOp = (*this)->getParentOp();
2411 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2412 return emitOpError(
"expected single non-empty parent region");
2414 if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2417 return emitOpError(
"expected parent op with LinalgOp interface");
2424LogicalResult IndexOp::verify() {
2425 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2427 return emitOpError(
"expected parent op with LinalgOp interface");
2428 if (linalgOp.getNumLoops() <= getDim())
2430 << getDim() <<
") to be lower than the number of loops ("
2431 << linalgOp.getNumLoops() <<
") of the enclosing LinalgOp";
2435OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2436 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2441 return OpFoldResult{};
2444 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2445 uint64_t dim = getDim();
2446 assert(dim < loopBounds.size() &&
"Dim is out of bounds");
2447 if (loopBounds[dim] == 1)
2448 return IntegerAttr::get(IndexType::get(
getContext()), 0);
2450 return OpFoldResult{};
2455#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2457#define GET_OP_CLASSES
2458#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2460#define GET_OP_CLASSES
2461#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2462#define GET_OP_CLASSES
2463#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2480 for (
unsigned i = 0; i < num; ++i)
2487 auto rangeA = llvm::make_range(a.begin(), a.end());
2488 auto rangeB = llvm::make_range(
b.begin(),
b.end());
2489 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2490 return llvm::to_vector<4>(concatRanges);
2494 if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
2496 for (
auto size :
memref.getShape())
2503 if (
auto as =
memref.getMemorySpace()) {
2504 if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
2505 ss <<
"as" << attr.getInt();
2511 if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
2514 vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss <<
"x"; });
2527 assert(isa<LinalgOp>(op));
2529 std::string fun =
"";
2531 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2532 fun = stringifyEnum(ufa.getValue()).str() +
"_";
2533 }
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2534 fun = stringifyEnum(bfa.getValue()).str() +
"_";
2538 llvm::replace(name,
'.',
'_');
2539 llvm::raw_string_ostream ss(name);
2543 return std::string();
2558 LogicalResult matchAndRewrite(LinalgOp op,
2560 for (
OpOperand &opOperand : op->getOpOperands()) {
2564 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2567 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2578struct FoldTensorCastConsumerOp :
public OpRewritePattern<tensor::CastOp> {
2579 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2581 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2582 PatternRewriter &rewriter)
const override {
2586 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2593 if (castOp->getBlock() != linalgOp->getBlock())
2596 OpBuilder::InsertionGuard guard(rewriter);
2599 Location loc = linalgOp.getLoc();
2600 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2603 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2609 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2611 tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
2612 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2613 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2614 linalgOp.getDpsInits().end());
2615 outputOperands[resultNumber] = newOperand;
2616 newOperands.append(outputOperands.begin(), outputOperands.end());
2618 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2619 linalgOp->result_type_end());
2620 resultTypes[resultNumber] = resultType;
2621 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2624 Value castBack = tensor::CastOp::create(
2628 results[resultNumber] = castBack;
2637static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2638 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2639 for (OpOperand &opOperand : operands) {
2640 if (linalgOp.isScalar(&opOperand))
2642 Value src = opOperand.get();
2643 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2644 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2650 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2652 if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2653 Value castSource = castOp.getSource();
2654 auto castSourceType =
2655 llvm::dyn_cast<RankedTensorType>(castSource.
getType());
2656 if (castSourceType && castSourceType.hasStaticShape())
2657 sourceShape = castSourceType.getShape();
2663 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2664 if (sourceType.isDynamicDim(i))
2666 if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2667 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2677static void createNewOperandWithStaticSizes(
2678 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2679 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2680 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2681 bool &changeNeeded) {
2682 Value src = opOperand->
get();
2683 newOperands.push_back(src);
2684 if (linalgOp.isScalar(opOperand))
2686 auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
2687 Type resultType = sourceType;
2688 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2689 resultTypes.push_back(resultType);
2692 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2693 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2694 SmallVector<int64_t> newShape;
2697 bool newOperandNeeded =
false;
2698 for (
unsigned i = 0; i < sourceShape.size(); i++) {
2699 int64_t dimShape = sourceShape[i];
2700 AffineExpr dimExpr = sourceMap.
getResult(i);
2701 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2702 newShape.push_back(dimShape);
2708 newShape.push_back(affineExprToSize[dimExpr]);
2709 newOperandNeeded =
true;
2711 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2712 sourceType.getEncoding());
2713 if (newOperandNeeded) {
2714 changeNeeded =
true;
2717 Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2719 newOperands[index] = newOperand;
2721 if (linalgOp.isDpsInit(opOperand))
2722 resultTypes.push_back(resultType);
2728struct InferStaticShapeOfOperands :
public OpInterfaceRewritePattern<LinalgOp> {
2729 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2731 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2732 PatternRewriter &rewriter)
const override {
2733 if (!linalgOp.hasPureTensorSemantics())
2737 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2738 return !map.isProjectedPermutation();
2743 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2744 Location loc = linalgOp.getLoc();
2748 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2750 SmallVector<Value> newOperands;
2751 SmallVector<Type> resultTypes;
2755 bool changeNeeded =
false;
2756 newOperands.reserve(linalgOp->getNumOperands());
2757 resultTypes.reserve(linalgOp.getNumDpsInits());
2760 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2761 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2762 affineExprToSize, linalgOp, newOperands,
2763 resultTypes, changeNeeded);
2772 Operation *newOp =
clone(rewriter, linalgOp, resultTypes, newOperands);
2773 SmallVector<Value> replacements;
2775 for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
2776 Value newResult = std::get<1>(it);
2777 Value oldResult = std::get<0>(it);
2778 Type newType = newResult.
getType();
2779 Type oldType = oldResult.
getType();
2780 replacements.push_back(
2781 (newType != oldType)
2782 ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2785 rewriter.
replaceOp(linalgOp, replacements);
2799LogicalResult SoftmaxOp::verify() {
2800 ShapedType inputType = getInputOperandType();
2801 ShapedType outputType = getOutputOperandType();
2803 ArrayRef<int64_t> inputShape = inputType.getShape();
2804 ArrayRef<int64_t> outputShape = outputType.getShape();
2808 int64_t inputRank = getInputOperandRank();
2809 int64_t dimension = getDimension();
2810 if ((dimension < 0) || (dimension >= inputRank))
2811 return emitOpError(
"incorrect dimension specified");
2816SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2817 int64_t operandRank = getInputOperandRank();
2818 SmallVector<Range> loopBounds(operandRank);
2819 Location loc = getLoc();
2822 Value source = getInput();
2823 for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
2824 loopBounds[dim].offset = zero;
2825 loopBounds[dim].size =
getDimValue(builder, loc, source, dim);
2826 loopBounds[dim].stride = one;
2831SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2832 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2833 utils::IteratorType::parallel);
2834 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2835 return iteratorTypes;
2838FailureOr<TilingResult>
2839SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2840 ArrayRef<OpFoldResult> offsets,
2841 ArrayRef<OpFoldResult> sizes) {
2842 int64_t rank = getInputOperandRank();
2844 SmallVector<OpFoldResult> strides(rank, oneAttr);
2845 SmallVector<Value> tiledOperands;
2846 Operation *inputSlice =
2847 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2849 return emitOpError(
"failed to compute input slice");
2851 tiledOperands.emplace_back(inputSlice->
getResult(0));
2852 Operation *outputSlice =
2853 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2855 return emitOpError(
"failed to compute output slice");
2857 tiledOperands.emplace_back(outputSlice->
getResult(0));
2859 SmallVector<Type, 4> resultTypes;
2860 if (hasPureTensorSemantics())
2861 resultTypes.push_back(tiledOperands[1].
getType());
2862 Operation *tiledOp =
2863 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2865 return TilingResult{
2868 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2871LogicalResult SoftmaxOp::getResultTilePosition(
2872 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2873 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2874 SmallVector<OpFoldResult> &resultSizes) {
2875 if (resultNumber == 0) {
2876 resultOffsets.assign(offsets.begin(), offsets.end());
2877 resultSizes.assign(sizes.begin(), sizes.end());
2884LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2889SoftmaxOp::reifyResultShapes(OpBuilder &
b,
2891 SmallVector<OpFoldResult> shapes;
2892 Location loc = getOperation()->getLoc();
2893 IRRewriter rewriter(
b);
2894 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2895 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2896 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2897 if (!outputShapedType.isDynamicDim(dim)) {
2899 shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
2906 reifiedReturnShapes.emplace_back(std::move(shapes));
2910void SoftmaxOp::getEffects(
2911 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2913 for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2914 if (!llvm::isa<MemRefType>(operand.
getType()))
2917 &getOperation()->getOpOperand(index), 0,
2922 for (OpOperand &operand : getDpsInitsMutable()) {
2923 if (!llvm::isa<MemRefType>(operand.get().
getType()))
2954static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2956 int64_t dim,
bool allParallel =
false) {
2958 utils::IteratorType::parallel);
2960 iteratorTypes[dim] = utils::IteratorType::reduction;
2964 for (
int i = 0; i < inputRank; i++) {
2971 return std::make_tuple(iteratorTypes, indexingMaps);
2976template <
typename T>
2979 auto inputType = cast<ShapedType>(input.
getType());
2981 int64_t inputRank = inputShape.size();
2982 auto [iteratorTypes, indexingMaps] =
2984 assert(indexingMaps.size() == 2 &&
2985 "We should have two maps: 1 for the input, 1 for the output");
2986 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
2988 auto genericOp = linalg::GenericOp::create(
2989 builder, loc, output.
getType(), input, output, indexingMaps,
2991 Value result = T::create(b, loc, args[0], args[1]);
2992 linalg::YieldOp::create(b, loc, result);
2994 return genericOp.getResult(0);
3002 auto inputType = cast<ShapedType>(input.
getType());
3004 int64_t inputRank = inputShape.size();
3006 builder, inputRank, dim,
true);
3007 assert(indexingMaps.size() == 2 &&
"We should have one map for each input");
3008 assert(indexingMaps[0].isIdentity() &&
"input map should be identity");
3010 indexingMaps.push_back(indexingMaps[0]);
3011 auto genericOp = linalg::GenericOp::create(
3013 indexingMaps, iteratorTypes,
3015 Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
3016 Value result = math::ExpOp::create(b, loc, diff);
3017 linalg::YieldOp::create(b, loc, result);
3019 return genericOp.getResult(0);
3029 auto inputType = cast<ShapedType>(numerator.
getType());
3031 int64_t inputRank = inputShape.size();
3033 builder, inputRank, dim,
true);
3034 assert(indexingMaps.size() == 2 &&
3035 "We should have one map for each input (2)");
3036 assert(indexingMaps[0].isIdentity() &&
"Numerator map should be identity");
3038 indexingMaps.push_back(indexingMaps[0]);
3039 auto genericOp = linalg::GenericOp::create(
3041 output, indexingMaps, iteratorTypes,
3043 Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3044 linalg::YieldOp::create(b, loc, result);
3046 return genericOp.getResult(0);
3068FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
3069 OpBuilder::InsertionGuard guard(
b);
3070 b.setInsertionPoint(*
this);
3071 Location loc = getLoc();
3072 Value input = getInput();
3073 ShapedType inputType = getInputOperandType();
3074 Type elementType = inputType.getElementType();
3075 int64_t reductionDim = getDimension();
3077 Value output = getOutput();
3078 dims.erase(dims.begin() + reductionDim);
3080 Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
3082 elementType,
b, loc,
3084 Value neutralForMaxFInit =
3085 linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
3097 linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
3103 buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
3104 return SmallVector<Value>{
result};
3111LogicalResult WinogradFilterTransformOp::verify() {
3112 auto filterType = cast<ShapedType>(getFilter().
getType());
3113 ArrayRef<int64_t> filterShape = filterType.getShape();
3114 int64_t filterH = filterShape[getFilterHDim()];
3115 int64_t filterW = filterShape[getFilterWDim()];
3116 WinogradConv2DFmr fmr = getFmr();
3120 if (filterH != r && filterH != 1)
3121 return emitOpError(
"expect filter height either equals to r or 1");
3122 if (filterW != r && filterW != 1)
3123 return emitOpError(
"expect filter width either equals to r or 1");
3124 if (filterH == 1 && filterW == 1)
3125 return emitOpError(
"expect either filter height or width equals to r");
3127 SmallVector<int64_t> expectedOutputShape;
3128 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3129 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3130 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3131 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3133 auto outputType = cast<ShapedType>(getOutput().
getType());
3134 ArrayRef<int64_t> outputShape = outputType.getShape();
3136 return emitOpError(
"the output shape is not expected");
3142WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3143 Location loc = getLoc();
3146 Value filter = getFilter();
3147 int64_t filterRank = getFilterOperandRank();
3148 SmallVector<Range> loopBounds(filterRank);
3149 for (
unsigned dim = 0; dim < filterRank; ++dim) {
3150 loopBounds[dim].offset = zeroAttr;
3151 loopBounds[dim].size =
getDimValue(builder, loc, filter, dim);
3152 loopBounds[dim].stride = oneAttr;
3157SmallVector<utils::IteratorType>
3158WinogradFilterTransformOp::getLoopIteratorTypes() {
3159 int64_t filterRank = getFilterOperandRank();
3160 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3161 utils::IteratorType::parallel);
3162 return iteratorTypes;
3165LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3166 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3167 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3168 SmallVector<OpFoldResult> &resultSizes) {
3170 ShapedType filterType = getFilterOperandType();
3171 ArrayRef<int64_t> filterShape = filterType.getShape();
3172 int64_t filterH = filterShape[getFilterHDim()];
3173 int64_t filterW = filterShape[getFilterWDim()];
3174 WinogradConv2DFmr fmr = getFmr();
3177 int64_t alpha = m + r - 1;
3178 int64_t alphaH = filterH != 1 ? alpha : 1;
3179 int64_t alphaW = filterW != 1 ? alpha : 1;
3183 resultOffsets.append(
3184 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3186 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3197FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3198 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3199 ArrayRef<OpFoldResult> sizes) {
3202 ShapedType filterType = getFilterOperandType();
3203 ArrayRef<int64_t> filterShape = filterType.getShape();
3204 int64_t filterH = filterShape[getFilterHDim()];
3205 int64_t filterW = filterShape[getFilterWDim()];
3208 SmallVector<Value> tiledOperands;
3209 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3211 sliceOffsets.append(
3212 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3213 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3214 sizes[getFilterCDim()]});
3215 int64_t filterRank = getFilterOperandRank();
3216 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3217 Location loc = getLoc();
3218 auto filterSlice = tensor::ExtractSliceOp::create(
3219 builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3220 tiledOperands.emplace_back(filterSlice);
3222 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3227 int64_t outputRank = getOutputOperandRank();
3228 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3229 auto outputSlice = tensor::ExtractSliceOp::create(
3230 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3231 tiledOperands.emplace_back(outputSlice);
3233 SmallVector<Type> resultTypes;
3234 resultTypes.push_back(tiledOperands[1].
getType());
3235 Operation *tiledOp =
3236 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3238 return TilingResult{
3241 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3248LogicalResult WinogradInputTransformOp::verify() {
3249 auto inputType = cast<ShapedType>(getInput().
getType());
3250 ArrayRef<int64_t> inputShape = inputType.getShape();
3251 int64_t inputH = inputShape[getInputHDim()];
3252 int64_t inputW = inputShape[getInputWDim()];
3253 WinogradConv2DFmr fmr = getFmr();
3256 int64_t tileSize = m + r - 1;
3258 auto outputType = cast<ShapedType>(getOutput().
getType());
3259 ArrayRef<int64_t> outputShape = outputType.getShape();
3260 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3261 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3263 SmallVector<int64_t> expectedOutputShape(6, inputH);
3264 if (ShapedType::isDynamic(inputH)) {
3265 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3266 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3268 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3269 expectedOutputShape[getOutputTileHDim()] =
3270 leftTransform ? (inputH - (r - 1)) / m : inputH;
3272 if (ShapedType::isDynamic(inputW)) {
3273 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3274 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3276 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3277 expectedOutputShape[getOutputTileWDim()] =
3278 rightTransform ? (inputW - (r - 1)) / m : inputW;
3280 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3281 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3284 return emitOpError(
"the output shape is not expected");
3290WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3291 Location loc = getLoc();
3294 Value output = getOutput();
3295 int64_t outputRank = getOutputOperandRank();
3296 SmallVector<Range> loopBounds(outputRank);
3297 for (
unsigned dim = 0; dim < outputRank; ++dim) {
3298 loopBounds[dim].offset = zeroAttr;
3300 loopBounds[dim].size =
getDimValue(builder, loc, output, dim);
3301 loopBounds[dim].stride = oneAttr;
3306SmallVector<utils::IteratorType>
3307WinogradInputTransformOp::getLoopIteratorTypes() {
3308 int64_t outputRank = getOutputOperandRank();
3309 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3310 utils::IteratorType::parallel);
3311 return iteratorTypes;
3314LogicalResult WinogradInputTransformOp::getResultTilePosition(
3315 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3316 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3317 SmallVector<OpFoldResult> &resultSizes) {
3319 ShapedType outputType = getOutputOperandType();
3320 ArrayRef<int64_t> outputShape = outputType.getShape();
3321 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3322 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3324 WinogradConv2DFmr fmr = getFmr();
3327 int64_t alpha = m + r - 1;
3328 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3329 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3334 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3335 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3336 offsets[getOutputCDim()]});
3337 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3338 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3339 sizes[getOutputCDim()]});
3350FailureOr<TilingResult>
3351WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3352 ArrayRef<OpFoldResult> offsets,
3353 ArrayRef<OpFoldResult> sizes) {
3355 WinogradConv2DFmr fmr = getFmr();
3359 ShapedType outputType = getOutputOperandType();
3360 ArrayRef<int64_t> outputShape = outputType.getShape();
3361 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3362 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3364 Location loc = getLoc();
3366 auto identityAffineMap =
3368 auto offsetAffineMap =
3371 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3372 offsets[getOutputTileHDim()]);
3374 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3375 offsets[getOutputTileWDim()]);
3379 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3381 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3383 SmallVector<Value> tiledOperands;
3384 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3386 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3387 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3388 sliceOffsets.append(
3389 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3390 OpFoldResult sizeH =
3391 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3392 OpFoldResult sizeW =
3393 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3395 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3396 int64_t inputRank = getInputOperandRank();
3397 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3398 auto inputSlice = tensor::ExtractSliceOp::create(
3399 builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3400 tiledOperands.emplace_back(inputSlice);
3402 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3407 int64_t outputRank = getOutputOperandRank();
3408 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3409 auto outputSlice = tensor::ExtractSliceOp::create(
3410 builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3411 tiledOperands.emplace_back(outputSlice);
3413 SmallVector<Type> resultTypes;
3414 resultTypes.push_back(tiledOperands[1].
getType());
3415 Operation *tiledOp =
3416 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3418 return TilingResult{
3421 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3428LogicalResult WinogradOutputTransformOp::verify() {
3429 auto valueType = cast<ShapedType>(getValue().
getType());
3430 ArrayRef<int64_t> valueShape = valueType.getShape();
3431 int64_t valueH = valueShape[getValueAlphaHDim()];
3432 int64_t valueW = valueShape[getValueAlphaWDim()];
3433 int64_t valueTileH = valueShape[getValueTileHDim()];
3434 int64_t valueTileW = valueShape[getValueTileWDim()];
3435 WinogradConv2DFmr fmr = getFmr();
3438 bool leftTransform = valueH != 1;
3439 bool rightTransform = valueW != 1;
3441 int64_t outputRank = getOutputOperandRank();
3442 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3443 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3444 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3446 if (valueH != (leftTransform ? m + r - 1 : 1))
3447 return emitOpError(
"expect input height equals to input tile size");
3448 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3450 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3451 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3453 if (valueW != (rightTransform ? m + r - 1 : 1))
3454 return emitOpError(
"expect input width equals to input tile size");
3455 expectedOutputShape[getOutputWDim()] =
3456 (rightTransform ? m : 1) * valueTileW;
3458 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3459 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3461 auto outputType = cast<ShapedType>(getOutput().
getType());
3462 ArrayRef<int64_t> outputShape = outputType.getShape();
3464 return emitOpError(
"the output shape is not expected");
3470WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3471 Location loc = getLoc();
3474 Value value = getValue();
3475 int64_t valueRank = getValueOperandRank();
3476 SmallVector<Range> loopBounds(valueRank);
3477 for (
unsigned dim = 0; dim < valueRank; ++dim) {
3478 loopBounds[dim].offset = zeroAttr;
3480 loopBounds[dim].size =
getDimValue(builder, loc, value, dim);
3481 loopBounds[dim].stride = oneAttr;
3486SmallVector<utils::IteratorType>
3487WinogradOutputTransformOp::getLoopIteratorTypes() {
3488 int64_t valueRank = getValueOperandRank();
3489 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3490 utils::IteratorType::parallel);
3491 return iteratorTypes;
3494LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3495 OpBuilder &builder,
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3496 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3497 SmallVector<OpFoldResult> &resultSizes) {
3498 WinogradConv2DFmr fmr = getFmr();
3502 Location loc = getLoc();
3504 auto identityAffineMap =
3509 ShapedType valueType = getValueOperandType();
3510 ArrayRef<int64_t> valueShape = valueType.getShape();
3511 int64_t valueH = valueShape[0];
3512 int64_t valueW = valueShape[1];
3514 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3515 offsets[getValueTileHDim()]);
3517 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3518 offsets[getValueTileWDim()]);
3520 builder, loc, affineMap, sizes[getValueTileHDim()]);
3522 builder, loc, affineMap, sizes[getValueTileWDim()]);
3525 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3526 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3527 OpFoldResult sizeH =
3528 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3529 OpFoldResult sizeW =
3530 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3532 resultOffsets.append(
3533 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3535 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3545FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3546 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3547 ArrayRef<OpFoldResult> sizes) {
3550 Location loc = getLoc();
3551 SmallVector<Value> tiledOperands;
3552 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3554 ShapedType valueType = getValueOperandType();
3555 ArrayRef<int64_t> valueShape = valueType.getShape();
3556 int64_t alphaH = valueShape[getValueAlphaHDim()];
3557 int64_t alphaW = valueShape[getValueAlphaWDim()];
3561 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3562 offsets[getValueTileWDim()], offsets[getValueNDim()],
3563 offsets[getValueFDim()]});
3564 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3565 sizes[getValueTileWDim()], sizes[getValueNDim()],
3566 sizes[getValueFDim()]});
3567 int64_t valueRank = getValueOperandRank();
3568 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3569 auto valueSlice = tensor::ExtractSliceOp::create(
3570 builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3571 tiledOperands.emplace_back(valueSlice);
3573 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3578 int64_t outputRank = getOutputOperandRank();
3579 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3580 auto outputSlice = tensor::ExtractSliceOp::create(
3581 builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3582 tiledOperands.emplace_back(outputSlice);
3584 SmallVector<Type> resultTypes;
3585 resultTypes.push_back(tiledOperands[1].
getType());
3586 Operation *tiledOp =
3587 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3589 return TilingResult{
3592 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3606 llvm::set_union(explicitSet, defaultSet);
3607 return explicitSet == defaultSet;
3627 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3629 auto opIndexingMap = opIndexingMaps[opIndex];
3630 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3633 return matmulOp->emitOpError()
3634 <<
"Unexpected dim expression in map result.";
3637 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3638 return matmulOp->emitOpError()
3639 <<
"Invalid broadcast requested, should be (d2).";
3648template <
typename OpTy>
3651 AffineMap defaultIndexingMap,
bool isLHS) {
3652 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3653 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3654 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3657 return batchVariantMatmulOp->emitOpError()
3658 <<
"Unexpected result dim expression (outside the set of default "
3663 return batchVariantMatmulOp->emitOpError()
3664 <<
"no. of result dim expressions exceeds 3.";
3666 auto hasValidBatchDim = [](
AffineMap map) {
3673 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3674 return batchVariantMatmulOp->emitOpError()
3675 <<
"Invalid broadcast requested.";
3676 }
else if (!hasValidBatchDim(opIndexingMap)) {
3677 return batchVariantMatmulOp->emitOpError()
3678 <<
"Invalid batch dimension expression.";
3686template <
typename OpTy>
3689 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3690 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3691 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3692 if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3695 return batchVariantMatmulOp->emitOpError()
3696 <<
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
3699 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3701 return batchVariantMatmulOp->emitOpError()
3702 <<
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
3706 auto areValidOutputResultDim = [&](
AffineMap outputMap) {
3707 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3708 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3709 outputMap.getResult(1).isFunctionOfDim(1) &&
3710 outputMap.getResult(2).isFunctionOfDim(2)
3711 : outputMap.getResult(0).isFunctionOfDim(1) &&
3712 outputMap.getResult(1).isFunctionOfDim(2);
3715 if (!areValidOutputResultDim(opIndexingMap)) {
3716 return batchVariantMatmulOp->emitOpError()
3717 <<
"Invalid output map result dimension.";
3726template <
typename OpTy>
3731 batchVariantMatmulOp.getIndexingMapsArray();
3733 batchVariantMatmulOp.getDefaultIndexingMaps(
3734 batchVariantMatmulOp->getContext());
3736 if (opIndexingMaps.size() != 3)
3737 return batchVariantMatmulOp->emitOpError()
3738 <<
"Indexing_map attribute must have 3 affine maps.";
3740 auto opIndexingMap = opIndexingMaps[opIndex];
3741 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3749 defaultIndexingMap, opIndex == 0)))
3759 if (m == 2 && r == 3)
3760 return WinogradConv2DFmr::F_2_3;
3761 if (m == 4 && r == 3)
3762 return WinogradConv2DFmr::F_4_3;
3763 if (m == 2 && r == 5)
3764 return WinogradConv2DFmr::F_2_5;
3765 return std::nullopt;
3770 case WinogradConv2DFmr::F_2_3:
3772 case WinogradConv2DFmr::F_4_3:
3774 case WinogradConv2DFmr::F_2_5:
3777 llvm_unreachable(
"Unkown WinogradConv2DFmr");
3784static FailureOr<SmallVector<SmallVector<int64_t>>>
3787 for (
auto map : maps) {
3788 AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3792 for (
auto result : attr.getAffineMap().getResults()) {
3793 auto dim = dyn_cast<AffineDimExpr>(
result);
3796 pos.push_back(dim.getPosition());
3798 positions.push_back(pos);
3811 return indexingMaps;
3814bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3815 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3818 if (maps.size() != 3)
3823 return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3824 (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3825 (*positions)[2] == SmallVector<int64_t>{0, 1};
3828SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3829 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3830 utils::IteratorType::parallel,
3831 utils::IteratorType::reduction};
3834unsigned MatmulOp::getNumRegionArgs() {
return 3; }
3836std::string MatmulOp::getLibraryCallName() {
3840bool MatmulOp::hasDynamicIndexingMaps() {
return true; }
3844bool MatmulOp::hasUserDefinedMaps() {
3845 SmallVector<AffineMap, 3> defaultMaps =
3847 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3848 return defaultMaps != explicitMaps;
3853void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
3854 ArrayRef<NamedAttribute> attrs,
3857 emitError() <<
"MatmulOp regionBuilder expects 3 args, got "
3862 "MatmulOp regionBuilder expects 3 args");
3863 RegionBuilderHelper helper(
b, block);
3864 SmallVector<Value> yields;
3866 TypeFn castVal = TypeFn::cast_signed;
3867 const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
3868 return attr.
getName() ==
"cast";
3870 if (castIter != attrs.end()) {
3871 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3879 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2,
emitError);
3882 Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
3886 yields.push_back(value4);
3887 helper.yieldOutputs(yields);
3897bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3898 assert(bcastMap.
getNumResults() == 1 &&
"Expected single result dim expr.");
3899 AffineExpr expr = bcastMap.
getResult(0);
3909 ArrayAttr arrayAttr;
3913 if (llvm::any_of(arrayAttr,
3914 [](
auto elt) {
return !dyn_cast<AffineMapAttr>(elt); }))
3916 <<
"element of indexing_maps array is not an affine_map";
3923 if (failed(indexingMapsAttr))
3926 if (*indexingMapsAttr ==
nullptr) {
3927 auto indexingMapAttrs = llvm::map_to_vector(
3928 MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
3933 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
3935 MatmulOp::getRegionBuilder());
3938void MatmulOp::print(OpAsmPrinter &p) {
3939 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3940 MatmulOp::getDefaultIndexingMaps(
getContext()),
3941 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
3942 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3943 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3945 std::array<StringRef, 3> elidedAttrs = {
3946 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
3952LogicalResult MatmulOp::verify() {
3954 if (!hasUserDefinedMaps())
3957 for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
3964LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3968void MatmulOp::getEffects(
3969 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3971 if (hasPureTensorSemantics())
3980SmallVector<AffineMap>
3981MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3982 AffineExpr d0, d1, d2;
3988 return {mapLHS, mapRHS, mapOut};
3992 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3995 if (maps.size() != 3)
3998 if (failed(positions))
4010 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4018 build(builder, state, inputs, outputs, attributes);
4019 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4020 assert(res &&
"builder didn't return the right type");
4030 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4039 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4040 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4041 assert(res &&
"builder didn't return the right type");
4051 result.addAttribute(
"cast", cast);
4053 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4062 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4063 auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
4064 assert(res &&
"builder didn't return the right type");
4069 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4071 op->
getAttr(
"indexing_maps"));
4075MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4082 return {mapLHS, mapRHS, mapOut};
4086 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4089 if (maps.size() != 3)
4092 if (failed(positions))
4104 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4112 build(builder, state, inputs, outputs, attributes);
4113 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4114 assert(res &&
"builder didn't return the right type");
4124 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4133 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4134 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4135 assert(res &&
"builder didn't return the right type");
4145 result.addAttribute(
"cast", cast);
4147 MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4156 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4157 auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
4158 assert(res &&
"builder didn't return the right type");
4163 return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4165 op->
getAttr(
"indexing_maps"));
4169BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4176 return {mapLHS, mapRHS, mapOut};
4180 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4183 if (maps.size() != 3)
4186 if (failed(positions))
4197 BatchMatmulOp::getRegionBuilder(),
4198 getDefaultIndexingMaps(builder));
4206 build(builder, state, inputs, outputs, attributes);
4207 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4208 assert(res &&
"builder didn't return the right type");
4217 BatchMatmulOp::getRegionBuilder(),
4218 getDefaultIndexingMaps(builder));
4227 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4228 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4229 assert(res &&
"builder didn't return the right type");
4237 result.addAttribute(
"cast", cast);
4239 BatchMatmulOp::getRegionBuilder(),
4240 getDefaultIndexingMaps(builder));
4249 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4250 auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
4251 assert(res &&
"builder didn't return the right type");
4256 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4258 op->
getAttr(
"indexing_maps"));
4262BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
4269 return {mapLHS, mapRHS, mapOut};
4273 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4276 if (maps.size() != 3)
4279 if (failed(positions))
4290 BatchMatmulOp::getRegionBuilder(),
4291 getDefaultIndexingMaps(builder));
4299 build(builder, state, inputs, outputs, attributes);
4300 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4301 assert(res &&
"builder didn't return the right type");
4310 BatchMatmulOp::getRegionBuilder(),
4311 getDefaultIndexingMaps(builder));
4320 build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4321 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4322 assert(res &&
"builder didn't return the right type");
4330 result.addAttribute(
"cast", cast);
4332 BatchMatmulOp::getRegionBuilder(),
4333 getDefaultIndexingMaps(builder));
4342 build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4343 auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
4344 assert(res &&
"builder didn't return the right type");
4349 return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4351 op->
getAttr(
"indexing_maps"));
4359 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4370 auto dimExpr = dyn_cast<AffineDimExpr>(
result);
4371 assert(dimExpr &&
"affine_map is a projected permutation");
4372 dimsInOutput[dimExpr.getPosition()] =
true;
4376 for (
auto dimOccursInOutput : dimsInOutput)
4377 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4378 : utils::IteratorType::reduction);
4380 return iteratorTypes;
4383unsigned ContractOp::getNumRegionArgs() {
return 3; }
4386void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b,
Block &block,
4387 ArrayRef<NamedAttribute> attrs,
4390 emitError() <<
"ContractOp regionBuilder expects 3 args, got "
4395 "ContractOp regionBuilder expects 3 args");
4396 RegionBuilderHelper helper(
b, block);
4398 TypeFn castSignedness = TypeFn::cast_signed;
4399 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4400 return attr.
getName() ==
"cast";
4402 if (castIter != attrs.end()) {
4403 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4409 Value lhsAtOutType =
4410 helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
4411 Value rhsAtOutType =
4412 helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
4413 Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4415 if (!productAtOutType)
4421 helper.yieldOutputs({
result});
4424ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
4426 if (
failed(indexingMapsAttr) || *indexingMapsAttr ==
nullptr)
4428 "expected 'indexing_maps' attribute");
4429 result.addAttribute(
"indexing_maps", *indexingMapsAttr);
4435void ContractOp::print(OpAsmPrinter &p) {
4436 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4438 p, getOperation(), getInputs(), getOutputs(),
4439 {
"indexing_maps",
"operandSegmentSizes"});
4442LogicalResult ContractOp::verify() {
4443 int iterationSpaceDims = -1;
4448 SmallVector<size_t> inOccurrences;
4449 SmallVector<size_t> outOccurrences;
4452 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4453 bool isInput) -> LogicalResult {
4456 return emitError(
"provided affine_map is not a projected permutation");
4459 if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
4461 return emitError(
"ranks of shaped operand and results of corresponding "
4462 "affine_map differ");
4464 return emitError(
"affine_map specifies shaped access while operand has "
4469 if (iterationSpaceDims == -1) {
4471 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4472 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4473 }
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
4474 return emitError(
"iteration spaces of provided affine_maps differ");
4478 for (AffineExpr affineExpr : affineMap.
getResults()) {
4479 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4481 llvm_unreachable(
"affine_map is a projected permutation");
4484 inOccurrences[affineDimExpr.getPosition()] += 1;
4486 outOccurrences[affineDimExpr.getPosition()] += 1;
4492 for (
auto &&[affineMap, operandType, isInput] :
4493 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4494 SmallVector<bool>{
true,
true,
false})) {
4495 if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4499 bool hasContractingDim =
false;
4500 for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4501 size_t inOccCount = inOccurrences[dimIndex];
4502 size_t outOccCount = outOccurrences[dimIndex];
4505 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4507 if (inOccCount == 0 && outOccCount == 0)
4508 return emitError() <<
"iteration space dim at index " << dimIndex
4509 <<
" not used to access any operand";
4520 if (inOccCount == 1 && outOccCount != 1)
4522 <<
"iteration space dim at index " << dimIndex
4523 <<
" is neither a contracting dim nor of parallel iteration type";
4526 if (!hasContractingDim)
4527 return emitError(
"'indexing_maps' do not specify a contracting dimension");
4532LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4536void ContractOp::getEffects(
4537 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4539 if (hasPureTensorSemantics())
4551SmallVector<AffineMap>
4552BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4553 AffineExpr d0, d1, d2, d3;
4554 SmallVector<AffineMap> indexingMaps;
4556 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
4557 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
4558 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
4559 return indexingMaps;
4562bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4563 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4566 if (maps.size() != 3)
4571 return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4572 (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4573 (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4576SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4577 return SmallVector<utils::IteratorType>{
4578 utils::IteratorType::parallel, utils::IteratorType::parallel,
4579 utils::IteratorType::parallel, utils::IteratorType::reduction};
4582unsigned BatchMatmulOp::getNumRegionArgs() {
return 3; }
4584std::string BatchMatmulOp::getLibraryCallName() {
4590bool BatchMatmulOp::hasUserDefinedMaps() {
4591 SmallVector<AffineMap, 3> defaultMaps =
4593 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4594 return defaultMaps != explicitMaps;
4604bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
bool isLHS) {
4606 "Expected less than 3 result dim expr.");
4607 bool isValid =
false;
4608 enum Indices { batchPos, mPos, nPos, kPos };
4610 AffineExpr expr = bcastMap.
getResult(0);
4613 AffineExpr expr0 = bcastMap.
getResult(0);
4614 AffineExpr expr1 = bcastMap.
getResult(1);
4619 : ((expr0.isFunctionOfDim(batchPos) &&
4620 expr1.isFunctionOfDim(kPos)) ||
4621 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4626void BatchMatmulOp::regionBuilder(
4627 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4630 emitError() <<
"BatchMatmulOp regionBuilder expects 3 args, got "
4635 "BatchMatmulOp regionBuilder expects 3 args");
4636 RegionBuilderHelper helper(
b, block);
4637 SmallVector<Value> yields;
4639 TypeFn castVal = TypeFn::cast_signed;
4640 auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
4641 return attr.
getName() ==
"cast";
4643 if (castIter != attrs.end()) {
4644 if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4649 Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
4650 Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
4651 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4653 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
4654 yields.push_back(addVal);
4655 helper.yieldOutputs(yields);
4658ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
4659 SmallVector<Attribute, 3> indexingMapsAttr;
4671 if (!isa<AffineMapAttr>(mapAttr)) {
4673 "expected affine map attribute");
4675 indexingMapsAttr.push_back(mapAttr);
4685 if (indexingMapsAttr.empty()) {
4686 indexingMapsAttr = llvm::map_to_vector(
4687 BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
4688 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4690 result.addAttribute(
"indexing_maps",
4693 return ::parseNamedStructuredOp(parser,
result,
4694 BatchMatmulOp::getNumRegionArgs(),
4695 BatchMatmulOp::getRegionBuilder());
4698void BatchMatmulOp::print(OpAsmPrinter &p) {
4699 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4700 BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
4701 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4702 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4703 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4705 std::array<StringRef, 3> elidedAttrs = {
4706 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
4712LogicalResult BatchMatmulOp::verify() {
4715 if (!hasUserDefinedMaps())
4718 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
4725LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4726 SmallVectorImpl<OpFoldResult> &) {
4730void BatchMatmulOp::getEffects(
4731 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4733 if (hasPureTensorSemantics())
4747struct ArityGroupAndKind {
4749 ElementwiseArityGroup arityGroup;
4755 TernaryFn ternaryFn;
4759unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4760 return static_cast<unsigned>(arityGroup);
4765 constexpr int lastUnary =
static_cast<int>(ElementwiseCaseLimits::LastUnary);
4766 constexpr int lastBinary =
4767 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4768 constexpr int lastTernary =
4769 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4771 int val =
static_cast<int>(kind);
4772 ArityGroupAndKind
result;
4774 if (val < lastUnary) {
4775 result.arityGroup = ElementwiseArityGroup::Unary;
4776 result.kind.unaryFn =
static_cast<UnaryFn
>(val);
4779 if (val < lastBinary) {
4780 result.arityGroup = ElementwiseArityGroup::Binary;
4781 result.kind.binaryFn =
static_cast<BinaryFn
>(val - lastUnary);
4784 if (val >= lastTernary) {
4785 llvm_unreachable(
"unhandled ElementwiseFn");
4787 result.arityGroup = ElementwiseArityGroup::Ternary;
4788 result.kind.ternaryFn =
static_cast<TernaryFn
>(val - lastBinary);
4793 auto rank = getResultRank();
4798ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps,
unsigned numDims,
4804ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
4807 mlir::linalg::ElementwiseKind elemwiseKindVal;
4812 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4813 if (!elemwiseKindAttr)
4815 "expected ElementwiseKind attribute");
4816 elemwiseKindVal = elemwiseKindAttr.getValue();
4819 "expected operation 'kind' attribute");
4822 "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
4825 SmallVector<Attribute, 3> indexingMapsAttr;
4835 if (!isa<AffineMapAttr>(mapAttr))
4837 "expected affine map attribute");
4838 indexingMapsAttr.push_back(mapAttr);
4849 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
4851 ElementwiseOp::getRegionBuilder())) {
4853 "unable to parse elemwise op");
4857 if (indexingMapsAttr.empty()) {
4860 auto resultType =
result.operands[
result.operands.size() - 1].getType();
4861 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4864 "return type needs to be shaped type");
4865 auto numDims = shapedType.getRank();
4866 indexingMapsAttr = llvm::map_to_vector(
4867 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4869 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4872 result.addAttribute(
"indexing_maps",
4877void ElementwiseOp::print(OpAsmPrinter &p) {
4880 SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes",
"kind",
4884 unsigned numDims = getResultRank();
4886 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4887 ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
4889 [](AffineMap map) -> Attribute {
return AffineMapAttr::get(map); });
4891 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4892 p <<
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4900void ElementwiseOp::regionBuilder(
4901 ImplicitLocOpBuilder &
b,
Block &block, ArrayRef<NamedAttribute> attrs,
4903 ElementwiseKind elemwiseKind;
4904 for (
auto attr : attrs) {
4905 if (attr.getName() ==
b.getStringAttr(
"kind")) {
4906 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4907 assert(kindAttr &&
"op kind attribute incorrectly set");
4908 elemwiseKind = kindAttr.getValue();
4914 auto arityGroup = groupAndKind.arityGroup;
4915 auto kind = groupAndKind.kind;
4917 getArityGroupAsUInt(arityGroup) + 1 ) {
4918 emitError() <<
"Elementwise regionBuilder expects "
4919 << (getArityGroupAsUInt(arityGroup) + 1) <<
" args, got "
4924 getArityGroupAsUInt(arityGroup) + 1
4925 &&
"Elementwise regionBuilder number of block args mismatch");
4927 RegionBuilderHelper helper(
b, block);
4928 SmallVector<Value> yields;
4931 if (arityGroup == ElementwiseArityGroup::Unary) {
4934 }
else if (arityGroup == ElementwiseArityGroup::Binary) {
4938 }
else if (arityGroup == ElementwiseArityGroup::Ternary) {
4943 assert(
false &&
"found unhandled category in elemwise");
4946 yields.push_back(
result);
4947 helper.yieldOutputs(yields);
4950LogicalResult ElementwiseOp::fold(FoldAdaptor,
4951 SmallVectorImpl<OpFoldResult> &) {
4955void ElementwiseOp::getEffects(
4956 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4958 if (hasPureTensorSemantics())
4971template <
typename OpTy,
typename>
4974 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4975 ? packOrUnPack.getDestType()
4976 : packOrUnPack.getSourceType();
4977 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
4978 ? packOrUnPack.getSourceType()
4979 : packOrUnPack.getDestType();
4981 packedType.getShape().take_front(unpackedType.getRank()));
4982 if (!packOrUnPack.getOuterDimsPerm().empty()) {
5004 for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
5006 .take_back(mixedTiles.size()),
5009 if (
shape == ShapedType::kDynamic) {
5010 newMixedTileSizes.push_back(std::get<1>(it));
5017 if (
Attribute attr = llvm::dyn_cast_if_present<Attribute>(
tile)) {
5019 newMixedTileSizes.push_back(
tile);
5022 "tile size and dim size don't match!");
5023 newMixedTileSizes.push_back(
5028 return newMixedTileSizes;
5031template <
typename OpTy>
5035 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5036 "applies to only pack or unpack operations");
5037 int64_t destRank = op.getDestRank();
5039 reifiedReturnShapes[0] =
5044template <
typename OpTy>
5046 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5047 "applies to only pack or unpack operations");
5051 assert(tiles.size() == dimsToTile.size() &&
5052 "tiles must match indices of dimension to block");
5054 for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5055 dimAndTileMapping[dimsToTile[i]] = tiles[i];
5056 return dimAndTileMapping;
5059template <
typename OpTy>
5061 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5062 "applies to only pack or unpack operations");
5065 unsigned dynamicValIndex = 0;
5066 for (
int64_t staticTile : op.getStaticInnerTiles()) {
5067 if (ShapedType::isStatic(staticTile))
5070 mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5072 return mixedInnerTiles;
5075template <
typename OpTy>
5077 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5078 "applies to only pack or unpack operations");
5091 size_t dimsPosSize = dimsPos.size();
5092 if (dimsPosSize > rank)
5095 if (dimsPosSize != uniqued.size())
5097 return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
5098 return dimPos < 0 || dimPos >=
static_cast<int64_t>(rank);
5102template <
typename OpTy>
5104 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5105 "applies to only pack or unpack operations");
5106 Operation *op = packOrUnPack.getOperation();
5114 if (!packOrUnPack.getSourceType().hasRank() ||
5115 !packOrUnPack.getDestType().hasRank())
5116 return op->
emitError(
"expected both source and destination to have rank");
5119 if (!packOrUnPack.hasPureBufferSemantics() &&
5120 !packOrUnPack.hasPureTensorSemantics())
5121 return op->
emitError(
"mixing tensor and buffer semantics is not allowed");
5122 const unsigned numResults = packOrUnPack.getNumResults();
5123 if (packOrUnPack.hasPureTensorSemantics() && numResults != 1)
5124 return op->
emitError(
"expected 1 result, got ") << numResults;
5125 if (packOrUnPack.hasPureBufferSemantics() && numResults != 0)
5126 return op->
emitError(
"expected 0 results, got ") << numResults;
5130 if (hasZeros(mixedTiles))
5131 return op->
emitError(
"invalid zero tile factor");
5134 ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
5135 ? packOrUnPack.getSourceType()
5136 : packOrUnPack.getDestType();
5137 size_t unpackedRank = unpackedType.getRank();
5141 return op->
emitError(
"invalid inner_dims_pos vector");
5143 return op->
emitError(
"invalid outer_dims_perm vector");
5144 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5145 return op->
emitError(
"outer_dims_perm must be a permutation or empty");
5149 if (mixedTiles.size() > unpackedRank) {
5150 return op->
emitError(
"tiling factors must be less than or equal to the "
5151 "input rank for pack or output rank for unpack");
5153 if (mixedTiles.size() != innerDimsPos.size()) {
5155 "tiling factors must equal the number of dimensions to tile");
5158 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5159 ? packOrUnPack.getDestType()
5160 : packOrUnPack.getSourceType();
5161 size_t packedRank = packedType.getRank();
5163 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5164 if (expectedPackedRank != packedRank) {
5166 "packed rank != (unpacked rank + num tiling factors), got ")
5167 << packedRank <<
" != " << expectedPackedRank;
5174 unpackedType.getShape(), packOrUnPack.getStaticTiles(),
5175 packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
5177 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
5179 [](std::tuple<int64_t, OpFoldResult> it) {
5180 int64_t shape = std::get<0>(it);
5181 if (Attribute attr =
5182 llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
5183 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5184 int64_t staticTileSize = intAttr.getValue().getSExtValue();
5185 return shape == staticTileSize;
5187 return ShapedType::isDynamic(
shape);
5189 return op->emitError(
"mismatch in inner tile sizes specified and shaped of "
5190 "tiled dimension in the packed type");
5194 auto elementType = unpackedType.getElementType();
5195 Type expectedType, actualType;
5196 if (packOrUnPack.hasPureTensorSemantics()) {
5197 expectedType = RankedTensorType::get(expectedPackedShape, elementType);
5198 actualType = RankedTensorType::get(packedType.getShape(), elementType);
5200 expectedType = MemRefType::get(expectedPackedShape, elementType);
5201 actualType = MemRefType::get(packedType.getShape(), elementType);
5203 return op->emitError(
"expected ")
5204 << expectedType <<
" for the packed domain value, got "
5217struct PackOrUnPackTransposeResult {
5224template <
typename OpTy>
5225static PackOrUnPackTransposeResult
5229 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5230 "applies to only pack or unpack operations");
5231 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5232 "some permutation must be non-empty");
5233 PackOrUnPackTransposeResult metadata;
5234 metadata.innerDimsPos =
5236 metadata.innerTiles =
5238 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5239 ? packOrUnPackOp.getSourceRank()
5240 : packOrUnPackOp.getDestRank();
5241 metadata.outerDimsPerm =
5242 packOrUnPackOp.getOuterDimsPerm().empty()
5243 ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5245 if (!innerPermutation.empty()) {
5246 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5248 "invalid inner permutation");
5252 if (!outerPermutation.empty()) {
5253 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5255 "invalid outer permutation");
5266 if (!getResults().empty())
5267 setNameFn(getResult(),
"pack");
5277 Type sourceType, destType, resultType;
5297 if (parser.parseInteger(value))
5299 outerDimsPermVec.push_back(value);
5312 if (parser.parseInteger(value))
5314 innerDimsPosVec.push_back(value);
5326 for (
auto val : staticTilesAttr.
asArrayRef())
5327 staticTiles.push_back(val);
5344 bool isMemRef = llvm::isa<MemRefType>(sourceType);
5347 "pack/unpack requires '->' and destination type");
5351 resultType = destType;
5357 if (!paddingValue.empty() &&
5362 if (!dynamicTiles.empty() &&
5367 result.addAttribute(
"static_inner_tiles",
5369 result.addAttribute(
"inner_dims_pos", innerDimsPos);
5371 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
5374 1, 1,
static_cast<int32_t
>(paddingValue.size()),
5375 static_cast<int32_t
>(dynamicTiles.size())};
5376 result.addAttribute(
"operandSegmentSizes",
5380 result.addTypes(resultType);
5386 p <<
" " << getSource();
5388 if (getPaddingValue()) {
5389 p <<
" padding_value(" << getPaddingValue() <<
" : "
5390 << getPaddingValue().getType() <<
")";
5393 if (!getOuterDimsPerm().empty()) {
5394 p <<
" outer_dims_perm = [";
5395 llvm::interleaveComma(getOuterDimsPerm(), p);
5399 p <<
" inner_dims_pos = [";
5400 llvm::interleaveComma(getInnerDimsPos(), p);
5403 p <<
" inner_tiles = ";
5406 p <<
" into " << getDest();
5409 {
"static_inner_tiles",
"inner_dims_pos",
5410 "outer_dims_perm",
"operandSegmentSizes"});
5412 p <<
" : " << getSource().getType();
5413 p <<
" -> " << getDest().getType();
5419 std::optional<Value> paddingValue,
5421 assert(innerDimsPos.size() == innerTiles.size() &&
5422 "number of tile sizes specified must match the specified number of "
5423 "original dimensions to be tiled");
5427 build(builder, state, dest.
getType(), source, dest,
5428 paddingValue ? *paddingValue :
nullptr,
5429 outerDimsPerm.empty() ?
nullptr
5436PackOp::reifyResultShapes(
OpBuilder &builder,
5438 if (!hasPureTensorSemantics())
5456 ShapedType inputType = getSourceType();
5457 int64_t inputRank = inputType.getRank();
5458 return getDestType().getShape().take_front(inputRank);
5462 auto innerDimsPos = getInnerDimsPos();
5469 if (!outerDimPermInv.empty())
5473 for (
auto index : innerDimsPos)
5474 res.push_back(outerDims[
index]);
5485 outputShape.take_front(inputShape.size()));
5486 if (!outerDimsPerm.empty()) {
5487 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5488 "expected output and outer_dims_perm to have same size");
5492 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5493 if (ShapedType::isDynamic(inputShape[pos]))
5497 if (!constantTile) {
5498 if (ShapedType::isStatic(outputTileSizes[pos]) &&
5499 (inputShape[pos] % outputTileSizes[pos] != 0))
5501 }
else if (inputShape[pos] % (*constantTile) != 0) {
5514 outputShape.take_front(inputShape.size()));
5515 if (!outerDimsPerm.empty()) {
5516 assert(outerDimsPerm.size() == outputTileSizes.size() &&
5517 "expected output and outer_dims_perm to have same size");
5521 for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5522 if (ShapedType::isDynamic(inputShape[pos]) ||
5523 ShapedType::isDynamic(outputTileSizes[pos]))
5528 if (inputShape[pos] % (*constantTile) != 0)
5534LogicalResult PackOp::verify() {
5541 auto paddingValue = getPaddingValue();
5545 << getSourceType().getElementType()
5546 <<
" but got: " << paddingValue.getType();
5549 if (!paddingValue &&
5550 requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
5551 getDestType().
getShape(), getOuterDimsPerm(),
5554 "invalid tile factor or output size provided. Only full tiles are "
5555 "supported when padding_value is not set");
5565 for (
auto o : ofrs) {
5567 if (llvm::dyn_cast_if_present<Value>(o))
5568 result.push_back(ShapedType::kDynamic);
5580 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5581 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5583 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5584 resultShape[tiledDim.value()] = ShapedType::kDynamic;
5587 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5588 resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5592 if (!outerDimsPerm.empty())
5596 resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5609 for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5611 builder, loc, ceilDivExpr,
5612 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5614 if (!outerDimsPerm.empty())
5616 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5621 innerDimsPos, outerDimsPerm);
5627 for (
unsigned i = 0; i < resultDims.size(); ++i) {
5628 if (ShapedType::isStatic(resultTypeShape[i]))
5637RankedTensorType PackOp::inferPackedTensorType(
5641 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5642 return RankedTensorType::get(resultShape, sourceType.getElementType());
5645MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
5650 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5651 return MemRefType::get(resultShape, sourceType.getElementType());
5666 for (
auto [
index, value] : llvm::enumerate(
5667 llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
5668 if (ShapedType::isDynamic(value))
5669 mixedSizes.push_back(
5670 tensor::DimOp::create(
b, loc, source,
index).getResult());
5672 mixedSizes.push_back(
b.getIndexAttr(value));
5674 for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5675 int64_t dimPos = std::get<0>(it);
5677 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5679 if (!outerDimsPerm.empty())
5682 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5683 auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
5684 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
5691 *
this, innerPermutation, outerPermutation);
5692 Value transposedDest =
5693 createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
5694 metadata.innerDimsPos, metadata.outerDimsPerm);
5695 return PackOp::create(
b, loc, getSource(), transposedDest,
5696 metadata.innerDimsPos, metadata.innerTiles,
5697 getPaddingValue(), metadata.outerDimsPerm);
5700template <
typename OpTy>
5705 if (op.hasPureTensorSemantics())
5708 for (
OpOperand &opOperand : op.getOperation()->getOpOperands()) {
5709 if (!llvm::isa<MemRefType>(opOperand.
get().
getType()))
5712 if (&opOperand == &op.getSourceMutable()) {
5716 }
else if (&opOperand == &op.getDestMutable()) {
5727void PackOp::getEffects(
5733void UnPackOp::getEffects(
5740template <
typename OpTy>
5742 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5743 "applies to only pack or unpack operations");
5744 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5746 : op.getSourceType();
5748 for (
auto [dimDest,
tile] : llvm::zip(
5749 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5751 if (!constTileSize || ShapedType::isDynamic(dimDest))
5758 if (!hasPureTensorSemantics())
5760 if (getPaddingValue())
5775 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5777 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5789 auto packTiles = packOp.getMixedTiles();
5790 auto unPackTiles = unPackOp.getMixedTiles();
5791 if (packTiles.size() != unPackTiles.size())
5793 for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
5802 auto srcType = op.getSourceType();
5803 if (llvm::any_of(op.getInnerDimsPos(),
5804 [&](
int64_t pos) { return srcType.isDynamicDim(pos); }))
5806 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5808 return !PackOp::requirePaddingValue(
5809 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5810 op.getOuterDimsPerm(), op.getMixedTiles());
5817 bool changeNeeded =
false;
5818 srcShape.assign(packOp.getSourceType().getShape().begin(),
5819 packOp.getSourceType().getShape().end());
5820 destShape.assign(packOp.getDestType().getShape().begin(),
5821 packOp.getDestType().getShape().end());
5822 llvm::SmallSetVector<int64_t, 4> innerDims;
5823 innerDims.insert_range(packOp.getInnerDimsPos());
5825 if (!packOp.getOuterDimsPerm().empty())
5827 int srcRank = packOp.getSourceRank();
5828 for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
5829 if (innerDims.contains(i))
5833 if (!inverseOuterDimsPerm.empty())
5834 destPos = inverseOuterDimsPerm[srcPos];
5835 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5836 ShapedType::isDynamic(destShape[destPos])) {
5839 int64_t size = srcShape[srcPos];
5840 if (ShapedType::isDynamic(size))
5841 size = destShape[destPos];
5842 srcShape[srcPos] = size;
5843 destShape[destPos] = size;
5844 changeNeeded =
true;
5846 return changeNeeded;
5849LogicalResult PackOp::canonicalize(PackOp packOp,
PatternRewriter &rewriter) {
5851 if (!packOp.hasPureTensorSemantics())
5855 if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5856 if (unPackOp.getSourceType() == packOp.getDestType() &&
5857 !packOp.getPaddingValue() &&
5860 rewriter.
replaceOp(packOp, unPackOp.getSource());
5868 packOp.getPaddingValueMutable().clear();
5877 Value source = packOp.getSource();
5878 if (srcShape != packOp.getSourceType().getShape()) {
5879 auto newSrcType = packOp.getSourceType().clone(srcShape);
5881 tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5883 Value dest = packOp.getDest();
5884 ShapedType originalResultType = packOp.getDestType();
5885 bool needUpdateDestType = (destShape != originalResultType.getShape());
5886 if (needUpdateDestType) {
5887 auto newDestType = packOp.getDestType().clone(destShape);
5889 tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5892 packOp.getSourceMutable().assign(source);
5893 packOp.getDestMutable().assign(dest);
5894 packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
5897 if (needUpdateDestType) {
5899 auto castOp = tensor::CastOp::create(rewriter, loc, originalResultType,
5900 packOp.getResult());
5909template <
typename PackOrUnpackOp>
5911 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5912 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5913 "Function meant for pack/unpack");
5918 int64_t numPackedDims = innerDimsPos.size();
5919 auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5920 if (orderedDims != innerDimsPos) {
5926 int64_t packedRank = packedTensorType.getRank();
5936 return llvm::all_of(
5937 llvm::seq<int64_t>(0, packedRank - numPackedDims),
5938 [&packedShape](
int64_t i) {
return packedShape[i] == 1; });
5941bool PackOp::isLikePad() {
5942 auto packedTensorType =
5943 llvm::cast<ShapedType>((*this)->getResultTypes().front());
5947::mlir::LogicalResult
5948PackOp::fold(FoldAdaptor adaptor,
5950 if (!hasPureTensorSemantics())
5952 std::optional<Attribute> paddingValue;
5953 if (
auto pad = adaptor.getPaddingValue())
5955 if (
OpFoldResult reshapedSource = reshapeConstantSource(
5956 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5957 cast<TensorType>(getDestType()), paddingValue)) {
5958 results.push_back(reshapedSource);
5984 if (!op.hasPureTensorSemantics())
6003 PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
6004 op.getInnerDimsPos(), newMixedTileSizes,
6005 op.getPaddingValue(), op.getOuterDimsPerm());
6006 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6009 Value oldResult = op.getResult();
6010 Value newResult = newOp.getResult();
6013 ? tensor::CastOp::create(rewriter, op->getLoc(),
6014 oldResult.
getType(), newResult)
6027void UnPackOp::getAsmResultNames(
6029 if (!getResults().empty())
6030 setNameFn(getResult(),
"unpack");
6039 Type sourceType, destType, resultType;
6051 if (parser.parseInteger(value))
6053 outerDimsPermVec.push_back(value);
6066 if (parser.parseInteger(value))
6068 innerDimsPosVec.push_back(value);
6080 for (
auto val : staticTilesAttr.
asArrayRef())
6081 staticTiles.push_back(val);
6098 bool isMemRef = llvm::isa<MemRefType>(sourceType);
6101 "pack/unpack requires '->' and destination type");
6105 resultType = destType;
6111 if (!dynamicTiles.empty() &&
6116 result.addAttribute(
"static_inner_tiles",
6118 result.addAttribute(
"inner_dims_pos", innerDimsPos);
6120 result.addAttribute(
"outer_dims_perm", outerDimsPerm);
6123 1, 1, 0,
static_cast<int32_t
>(dynamicTiles.size())};
6124 result.addAttribute(
"operandSegmentSizes",
6128 result.addTypes(resultType);
6134 p <<
" " << getSource();
6136 if (!getOuterDimsPerm().empty()) {
6137 p <<
" outer_dims_perm = [";
6138 llvm::interleaveComma(getOuterDimsPerm(), p);
6142 p <<
" inner_dims_pos = [";
6143 llvm::interleaveComma(getInnerDimsPos(), p);
6146 p <<
" inner_tiles = ";
6149 p <<
" into " << getDest();
6152 {
"static_inner_tiles",
"inner_dims_pos",
6153 "outer_dims_perm",
"operandSegmentSizes"});
6155 p <<
" : " << getSource().getType();
6156 p <<
" -> " << getDest().getType();
6160UnPackOp::reifyResultShapes(
OpBuilder &builder,
6162 if (!hasPureTensorSemantics())
6180 ShapedType destType = getDestType();
6181 int64_t destRank = destType.getRank();
6182 return getSourceType().getShape().take_front(destRank);
6186 auto innerDimsPos = getInnerDimsPos();
6193 if (!outerDimPermInv.empty())
6197 for (
auto index : innerDimsPos)
6198 res.push_back(outerDims[
index]);
6203LogicalResult UnPackOp::verify() {
6208 if (!hasPureTensorSemantics())
6221 assert(innerDimsPos.size() == innerTiles.size() &&
6222 "number of tile sizes specified must match the specified number of "
6223 "original dimensions to be tiled");
6227 build(builder, state, dest.
getType(), source, dest,
6228 outerDimsPerm.empty() ?
nullptr
6246 auto srcType = llvm::cast<RankedTensorType>(source.
getType());
6248 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
6249 if (srcType.isDynamicDim(i))
6250 mixedSizes.push_back(
6251 tensor::DimOp::create(
b, loc, source, i).getResult());
6253 mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
6255 if (!outerDimsPerm.empty()) {
6260 for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
6261 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
6263 auto elemType = srcType.getElementType();
6264 return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
6268 Value transposedSource,
6272 *
this, innerPermutation, outerPermutation);
6273 return UnPackOp::create(
b, loc, transposedSource, getDest(),
6274 metadata.innerDimsPos, metadata.innerTiles,
6275 metadata.outerDimsPerm);
6282 bool changeNeeded =
false;
6283 srcShape.assign(op.getSourceType().getShape().begin(),
6284 op.getSourceType().getShape().end());
6285 destShape.assign(op.getDestType().getShape().begin(),
6286 op.getDestType().getShape().end());
6287 llvm::SmallSetVector<int64_t, 4> innerDims;
6288 innerDims.insert_range(op.getInnerDimsPos());
6290 if (!op.getOuterDimsPerm().empty())
6292 int destRank = op.getDestRank();
6293 for (
auto i : llvm::seq<int64_t>(0, destRank)) {
6294 if (innerDims.contains(i))
6298 if (!inverseOuterDimsPerm.empty())
6299 srcPos = inverseOuterDimsPerm[destPos];
6300 if (ShapedType::isDynamic(srcShape[srcPos]) ==
6301 ShapedType::isDynamic(destShape[destPos])) {
6304 int64_t size = srcShape[srcPos];
6305 if (ShapedType::isDynamic(size))
6306 size = destShape[destPos];
6307 srcShape[srcPos] = size;
6308 destShape[destPos] = size;
6309 changeNeeded =
true;
6311 return changeNeeded;
6314LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
6317 if (!unPackOp.hasPureTensorSemantics())
6321 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
6322 if (packOp.getSourceType() != unPackOp.getDestType())
6324 if (packOp.getPaddingValue() ||
6328 rewriter.
replaceOp(unPackOp, packOp.getSource());
6332 if (
auto dstStyleOp =
6333 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
6334 auto destValue = cast<OpResult>(unPackOp.getDest());
6335 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
6337 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
6341 if (unPackOp->hasOneUse()) {
6342 auto extractSliceUser =
6343 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
6344 if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
6347 auto newDest = tensor::ExtractSliceOp::create(
6348 rewriter, unPackOp->getLoc(), unPackOp.getDest(),
6349 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
6350 extractSliceUser.getMixedStrides());
6352 unPackOp.setDpsInitOperand(0, newDest);
6353 unPackOp.getResult().setType(newDest.
getType());
6355 rewriter.
replaceOp(extractSliceUser, unPackOp);
6364 Value source = unPackOp.getSource();
6365 if (srcShape != unPackOp.getSourceType().getShape()) {
6366 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
6367 source = tensor::CastOp::create(rewriter, loc, newSrcType,
6368 unPackOp.getSource());
6370 Value dest = unPackOp.getDest();
6371 if (destShape != unPackOp.getDestType().getShape()) {
6372 auto newDestType = unPackOp.getDestType().clone(destShape);
6373 dest = tensor::CastOp::create(rewriter, loc, newDestType,
6374 unPackOp.getDest());
6376 UnPackOp newOp = UnPackOp::create(
6377 rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
6378 unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
6380 unPackOp, unPackOp.getResult().
getType(), newOp.getResult());
6387bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
6389 if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
6394 RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
6397 for (
auto [pos, tileSize] :
6398 llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
6399 if (unpackedTypeAfterFold.isDynamicDim(pos))
6401 if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
6403 if (ShapedType::isDynamic(tileSize))
6405 int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
6406 unpackedTypeAfterFold.getDimSize(pos);
6407 if (paddingSize >= tileSize)
6413bool UnPackOp::isLikeUnPad() {
6414 ShapedType packedTensorType = getSourceType();
6418::mlir::LogicalResult
6419UnPackOp::fold(FoldAdaptor adaptor,
6420 ::llvm::SmallVectorImpl<OpFoldResult> &results) {
6422 if (!hasPureTensorSemantics())
6425 if (
OpFoldResult reshapedSource = reshapeConstantSource(
6426 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6427 cast<TensorType>(getResult().
getType()))) {
6428 results.push_back(reshapedSource);
6454 if (!op.hasPureTensorSemantics())
6463 Value sourceTensor = newOperands[0];
6467 rewriter, sourceTensor.
getType(), op.getMixedTiles());
6473 UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6474 newOperands[1], op.getInnerDimsPos(),
6475 newMixedTileSizes, op.getOuterDimsPerm());
6476 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6479 Value oldResult = op.getResult();
6480 Value newResult = newOp.getResult();
6483 ? tensor::CastOp::create(rewriter, op->getLoc(),
6484 oldResult.
getType(), newResult)
6498 utils::IteratorType::reduction, utils::IteratorType::parallel,
6499 utils::IteratorType::parallel, utils::IteratorType::reduction};
6503BatchReduceMatmulOp::getDefaultIndexingMaps(
MLIRContext *context) {
6507 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
6508 indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
6510 return indexingMaps;
6513bool BatchReduceMatmulOp::isDefaultIndexingMaps(
Attribute attr) {
6514 ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6517 if (maps.size() != 3)
6526unsigned BatchReduceMatmulOp::getNumRegionArgs() {
return 3; }
6528std::string BatchReduceMatmulOp::getLibraryCallName() {
6534bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6538 return defaultMaps != explicitMaps;
6548bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap,
6551 "Expected less than 3 result dim expr.");
6552 bool isValid =
false;
6553 enum Indices { batchPos, mPos, nPos, kPos };
6564 : ((expr0.isFunctionOfDim(batchPos) &&
6565 expr1.isFunctionOfDim(kPos)) ||
6566 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6571void BatchReduceMatmulOp::regionBuilder(
6575 emitError() <<
"BatchReduceMatmulOp regionBuilder expects 3 args, got "
6580 "BatchReduceMatmulOp regionBuilder expects 3 args");
6581 RegionBuilderHelper helper(
b, block);
6586 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
6588 helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
6589 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
6591 helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
6592 yields.push_back(addVal);
6593 helper.yieldOutputs(yields);
6596ParseResult BatchReduceMatmulOp::parse(
OpAsmParser &parser,
6609 if (!isa<AffineMapAttr>(mapAttr)) {
6611 "expected affine map attribute");
6613 indexingMapsAttr.push_back(mapAttr);
6623 if (indexingMapsAttr.empty()) {
6624 indexingMapsAttr = llvm::map_to_vector(
6625 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
6628 result.addAttribute(
"indexing_maps",
6630 return ::parseNamedStructuredOp(parser,
result,
6631 BatchReduceMatmulOp::getNumRegionArgs(),
6632 BatchReduceMatmulOp::getRegionBuilder());
6637 BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
6640 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6641 p <<
" indexing_maps = [";
6642 llvm::interleaveComma(getIndexingMaps(), p,
6648 "operandSegmentSizes",
"linalg.memoized_indexing_maps",
"indexing_maps"};
6654LogicalResult BatchReduceMatmulOp::verify() {
6657 if (!hasUserDefinedMaps())
6660 for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
6666LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6670void BatchReduceMatmulOp::getEffects(
6673 if (hasPureTensorSemantics())
6689void LinalgDialect::getCanonicalizationPatterns(
6698 return arith::ConstantOp::materialize(builder, value, type, loc);
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs)
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
static void buildGenericRegion(OpBuilder &builder, Location loc, Region ®ion, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Base type for affine expression.
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
AffineMap dropResults(ArrayRef< int64_t > positions) const
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
unsigned getNumDims() const
ArrayRef< AffineExpr > getResults() const
unsigned getNumResults() const
AffineExpr getResult(unsigned idx) const
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
@ Paren
Parens surrounding zero or more operands.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void decreaseIndent()
Decrease indentation.
virtual void increaseIndent()
Increase indentation.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
BlockArgListType getArguments()
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerAttr getIntegerAttr(Type type, int64_t value)
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerAttr getI64IntegerAttr(int64_t value)
StringAttr getStringAttr(const Twine &bytes)
AffineExpr getAffineDimExpr(unsigned position)
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
MLIRContext * getContext() const
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
IRValueT get() const
Return the current value being used by this operand.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
This class represents an operand of an operation.
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
unsigned getResultNumber() const
Returns the number of this result.
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
result_iterator result_begin()
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_iterator result_end()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static DefaultResource * get()
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Block * getParentBlock()
Return the Block in which this Value is defined.
bool hasOneUse() const
Returns true if this value has exactly one use.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
ArrayRef< T > asArrayRef() const
static Attribute parse(AsmParser &parser, Type type)
Specialization of linalg.batch_matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Specialization of linalg.batch_matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static bool classof(Operation *op)
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
Specialization of linalg.matmul op that has a transpose map on A.
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
static bool classof(Operation *op)
Specialization of linalg.matmul op that has a transpose map on B.
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
static bool classof(Operation *op)
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
static void getPackUnPackEffectsImpl(OpTy op, SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects)
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType)
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags, SmallVectorImpl< Type > *valueTypes=nullptr, AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Parser hooks for custom directive in assemblyFormat.
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
llvm::TypeSwitch< T, ResultT > TypeSwitch
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
llvm::function_ref< Fn > function_ref
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayRef< int64_t > integers, ArrayRef< bool > scalableFlags, TypeRange valueTypes=TypeRange(), AsmParser::Delimiter delimiter=AsmParser::Delimiter::Square)
Printer hooks for custom directive in assemblyFormat.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold back-to-back broadcasts together.
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
Fold transpose with transpose.
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
This is the representation of an operand reference.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override