39#include "llvm/ADT/DenseMap.h" 
   40#include "llvm/ADT/STLExtras.h" 
   41#include "llvm/ADT/SetOperations.h" 
   42#include "llvm/ADT/SmallVector.h" 
   43#include "llvm/ADT/StringSet.h" 
   44#include "llvm/ADT/TypeSwitch.h" 
   45#include "llvm/Support/FormatVariadic.h" 
   46#include "llvm/Support/InterleavedRange.h" 
   47#include "llvm/Support/LogicalResult.h" 
   48#include "llvm/Support/MathExtras.h" 
   49#include "llvm/Support/raw_ostream.h" 
   59  auto type = cast<ShapedType>(v.
getType());
 
   60  if (!type.isDynamicDim(dim))
 
   65          .Case<RankedTensorType>([&](RankedTensorType t) -> 
Value {
 
   66            return tensor::DimOp::create(builder, loc, v, dim);
 
   68          .Case<MemRefType>([&](MemRefType t) -> 
Value {
 
   69            return memref::DimOp::create(builder, loc, v, dim);
 
 
   80      .Case<RankedTensorType>([&](RankedTensorType t) -> 
Operation * {
 
   81        return tensor::ExtractSliceOp::create(
b, loc, source, offsets, sizes,
 
   84      .Case<MemRefType>([&](MemRefType type) -> 
Operation * {
 
   85        return memref::SubViewOp::create(
b, loc, source, offsets, sizes,
 
 
   97  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.
getType()))
 
   98    return b.createOrFold<memref::DimOp>(loc, source, dim);
 
   99  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.
getType()))
 
  100    return b.createOrFold<tensor::DimOp>(loc, source, dim);
 
  101  llvm_unreachable(
"Expected MemRefType or TensorType");
 
 
  106  auto shapedType = llvm::cast<ShapedType>(source.
getType());
 
  107  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
 
  109  return b.getIndexAttr(shapedType.getDimSize(dim));
 
 
  132  for (
auto containers : {inputTypes, outputTypes}) {
 
  133    for (
auto t : containers) {
 
  145      opBuilder.
createBlock(®ion, {}, argTypes, argLocs);
 
 
  161                              std::optional<TypeRange> resultTensorTypes,
 
  168  if (!resultTensorTypes)
 
  169    copy_if(outputs.
getTypes(), std::back_inserter(derivedResultTypes),
 
  170            llvm::IsaPred<RankedTensorType>);
 
  178      "operandSegmentSizes",
 
  179      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
 
  180                              static_cast<int32_t>(outputs.size())}));
 
 
  190                          std::optional<TypeRange> resultTensorTypes,
 
  197  indexingMapsAttrVal =
 
  199        return AffineMapAttr::get(map);
 
  201  state.
addAttribute(
"indexing_maps", 
b.getArrayAttr(indexingMapsAttrVal));
 
  203                           attributes, regionBuilder);
 
 
  207                               std::optional<TypeRange> resultTensorTypes,
 
  214  indexingMapsAttrVal =
 
  216        return AffineMapAttr::get(map);
 
  218  state.
addAttribute(
"indexing_maps", 
b.getArrayAttr(indexingMapsAttrVal));
 
  220                           attributes, regionBuilder);
 
 
  224                                     std::optional<TypeRange> resultTensorTypes,
 
  231  indexingMapsAttrVal =
 
  233        return AffineMapAttr::get(map);
 
  235  state.
addAttribute(
"indexing_maps", 
b.getArrayAttr(indexingMapsAttrVal));
 
  237                           attributes, regionBuilder);
 
 
  246                             bool addOperandSegmentSizes = 
true) {
 
  247  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
 
  276  if (parser.
resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
 
  278      parser.
resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
 
  282  if (addOperandSegmentSizes) {
 
  289    if (
result.propertiesAttr) {
 
  291      attrs.
append(
"operandSegmentSizes",
 
  293                       {static_cast<int32_t>(inputsOperands.size()),
 
  294                        static_cast<int32_t>(outputsOperands.size())}));
 
  297      result.addAttribute(
"operandSegmentSizes",
 
  299                              {static_cast<int32_t>(inputsOperands.size()),
 
  300                               static_cast<int32_t>(outputsOperands.size())}));
 
  303  if (!
result.propertiesAttr) {
 
  304    std::optional<RegisteredOperationName> info =
 
  305        result.name.getRegisteredInfo();
 
  307      if (failed(info->verifyInherentAttrs(
result.attributes, [&]() {
 
  308            return parser.emitError(attrsLoc)
 
  309                   << 
"'" << result.name.getStringRef() << 
"' op ";
 
 
  320    p << 
" ins(" << inputs << 
" : " << inputs.
getTypes() << 
")";
 
  321  if (!outputs.empty())
 
  322    p << 
" outs(" << outputs << 
" : " << outputs.
getTypes() << 
")";
 
 
  333  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
 
  336        llvm::formatv(
"[parseNamedStructuredOpRegion] ods-gen generated " 
  337                      "region expects {0} args, got {1}",
 
  338                      numRegionArgs, inputTypes.size() + outputTypes.size()));
 
  344      opBuilder, region, inputTypes, outputTypes, attrs,
 
 
  363                                          unsigned numRegionArgs,
 
  380  result.addTypes(outputTensorsTypes);
 
  382  std::unique_ptr<Region> region = std::make_unique<Region>();
 
  384                                   outputTypes, 
result.attributes.getAttrs(),
 
  387  result.addRegion(std::move(region));
 
 
  394  if (resultTypes.empty())
 
 
  439class RegionBuilderHelper {
 
  441  RegionBuilderHelper(OpBuilder &builder, 
Block &block)
 
  442      : builder(builder), block(block) {}
 
  445  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
 
  447    if (!isFloatingPoint(arg)) {
 
  449        emitError() << 
"unsupported non numeric type";
 
  452      llvm_unreachable(
"unsupported non numeric type");
 
  454    OpBuilder::InsertionGuard g(builder);
 
  455    builder.setInsertionPointToEnd(&block);
 
  458      return math::ExpOp::create(builder, arg.
getLoc(), arg);
 
  460      return math::LogOp::create(builder, arg.
getLoc(), arg);
 
  462      return math::AbsFOp::create(builder, arg.
getLoc(), arg);
 
  464      return math::CeilOp::create(builder, arg.
getLoc(), arg);
 
  466      return math::FloorOp::create(builder, arg.
getLoc(), arg);
 
  468      return arith::NegFOp::create(builder, arg.
getLoc(), arg);
 
  469    case UnaryFn::reciprocal: {
 
  470      Attribute oneAttr = builder.getOneAttr(arg.
getType());
 
  471      auto one = arith::ConstantOp::create(builder, arg.
getLoc(),
 
  472                                           ::cast<TypedAttr>(oneAttr));
 
  473      return arith::DivFOp::create(builder, arg.
getLoc(), one, arg);
 
  476      return math::RoundOp::create(builder, arg.
getLoc(), arg);
 
  478      return math::SqrtOp::create(builder, arg.
getLoc(), arg);
 
  480      return math::RsqrtOp::create(builder, arg.
getLoc(), arg);
 
  481    case UnaryFn::square:
 
  482      return arith::MulFOp::create(builder, arg.
getLoc(), arg, arg);
 
  484      return math::TanhOp::create(builder, arg.
getLoc(), arg);
 
  486      return math::ErfOp::create(builder, arg.
getLoc(), arg);
 
  489      emitError() << 
"unsupported unary function";
 
  492    llvm_unreachable(
"unsupported unary function");
 
  499  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
 
  501    bool allComplex = isComplex(arg0) && isComplex(arg1);
 
  502    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
 
  503    bool allInteger = isInteger(arg0) && isInteger(arg1);
 
  506    if (!allComplex && !allFloatingPoint && !allInteger) {
 
  509            << 
"Cannot build binary Linalg operation: expects allComplex, " 
  510               "allFloatingPoint, or allInteger, got " 
  514      llvm_unreachable(
"unsupported non numeric type");
 
  516    OpBuilder::InsertionGuard g(builder);
 
  517    builder.setInsertionPointToEnd(&block);
 
  521        return complex::AddOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  522      if (allFloatingPoint)
 
  523        return arith::AddFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  525        return arith::OrIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  526      return arith::AddIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  529        return complex::SubOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  530      if (allFloatingPoint)
 
  531        return arith::SubFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  534          emitError() << 
"unsupported operation: sub with bools";
 
  537        llvm_unreachable(
"unsupported operation: sub with bools");
 
  539      return arith::SubIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  542        return complex::MulOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  543      if (allFloatingPoint)
 
  544        return arith::MulFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  546        return arith::AndIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  547      return arith::MulIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  550        return complex::DivOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  551      if (allFloatingPoint)
 
  552        return arith::DivFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  555          emitError() << 
"unsupported operation: div with bools";
 
  558        llvm_unreachable(
"unsupported operation: div with bools");
 
  560      return arith::DivSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  561    case BinaryFn::div_unsigned:
 
  562      if (!allInteger || allBool) {
 
  564          emitError() << 
"unsupported operation: unsigned div not on uint";
 
  567        llvm_unreachable(
"unsupported operation: unsigned div not on uint");
 
  569      return arith::DivUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  570    case BinaryFn::max_signed:
 
  572      if (allFloatingPoint)
 
  573        return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  574      return arith::MaxSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  575    case BinaryFn::min_signed:
 
  577      if (allFloatingPoint)
 
  578        return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  579      return arith::MinSIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  580    case BinaryFn::max_unsigned:
 
  582      if (allFloatingPoint)
 
  583        return arith::MaximumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  584      return arith::MaxUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  585    case BinaryFn::min_unsigned:
 
  587      if (allFloatingPoint)
 
  588        return arith::MinimumFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  589      return arith::MinUIOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  591      assert(allFloatingPoint);
 
  592      return math::PowFOp::create(builder, arg0.
getLoc(), arg0, arg1);
 
  595      emitError() << 
"unsupported binary function";
 
  598    llvm_unreachable(
"unsupported binary function");
 
  602  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
 
  606    bool tailFloatingPoint =
 
  607        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
 
  608    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
 
  609    OpBuilder::InsertionGuard g(builder);
 
  610    builder.setInsertionPointToEnd(&block);
 
  612    case TernaryFn::select:
 
  613      if (!headBool && !(tailFloatingPoint || tailInteger))
 
  614        llvm_unreachable(
"unsupported non numeric type");
 
  615      return arith::SelectOp::create(builder, arg0.
getLoc(), arg0, arg1, arg2);
 
  618      emitError() << 
"unsupported ternary function";
 
  621    llvm_unreachable(
"unsupported ternary function");
 
  625  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
 
  628    case TypeFn::cast_signed:
 
  629      return cast(toType, operand, 
false);
 
  630    case TypeFn::cast_unsigned:
 
  631      return cast(toType, operand, 
true);
 
  634      emitError() << 
"unsupported type conversion function";
 
  637    llvm_unreachable(
"unsupported type conversion function");
 
  641    OpBuilder::InsertionGuard g(builder);
 
  642    builder.setInsertionPointToEnd(&block);
 
  643    Location loc = builder.getUnknownLoc();
 
  644    YieldOp::create(builder, loc, values);
 
  /// Parses `value` as an attribute in the builder's context and creates an
  /// arith::ConstantOp from it at the end of `block`, using an unknown
  /// location. Returns the created constant as a Value.
  /// NOTE(review): no failure handling is visible here; the result of
  /// parseAttribute is passed straight to ::cast<TypedAttr> — confirm that
  /// callers only pass strings that parse to a typed attribute.
  647  Value constant(
const std::string &value) {
 
  // Scoped guard: per the OpBuilder API it restores the builder's previous
  // insertion point once it goes out of scope.
  648    OpBuilder::InsertionGuard g(builder);
 
  649    builder.setInsertionPointToEnd(&block);
 
  650    Location loc = builder.getUnknownLoc();
 
  651    Attribute valueAttr = 
parseAttribute(value, builder.getContext());
 
  652    return arith::ConstantOp::create(builder, loc,
 
  653                                     ::cast<TypedAttr>(valueAttr));
 
  /// Creates an IndexOp at the end of `block`, passing `dim` through to the
  /// op and tagging it with an unknown location. Returns the result as a
  /// Value.
  656  Value index(int64_t dim) {
 
  // Scoped guard: per the OpBuilder API it restores the builder's previous
  // insertion point once it goes out of scope.
  657    OpBuilder::InsertionGuard g(builder);
 
  658    builder.setInsertionPointToEnd(&block);
 
  659    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
 
  /// Returns the IntegerType of bit width `width` obtained from the
  /// builder's context.
  662  Type getIntegerType(
unsigned width) {
 
  663    return IntegerType::get(builder.getContext(), width);
 
  666  Type getFloat32Type() { 
return Float32Type::get(builder.getContext()); }
 
  667  Type getFloat64Type() { 
return Float64Type::get(builder.getContext()); }
 
  674  Value cast(Type toType, Value operand, 
bool isUnsignedCast) {
 
  675    OpBuilder::InsertionGuard g(builder);
 
  676    builder.setInsertionPointToEnd(&block);
 
  677    auto loc = operand.
getLoc();
 
  678    if (isa<UnknownLoc>(loc)) {
 
  /// Returns true when the type of `value` is a ComplexType.
  688  bool isComplex(Value value) {
 
  689    return llvm::isa<ComplexType>(value.
getType());
 
  /// Returns true when the type of `value` is a FloatType.
  691  bool isFloatingPoint(Value value) {
 
  692    return llvm::isa<FloatType>(value.
getType());
 
  /// Returns true when the type of `value` is an IntegerType.
  694  bool isInteger(Value value) {
 
  695    return llvm::isa<IntegerType>(value.
getType());
 
  711  using OpRewritePattern<CopyOp>::OpRewritePattern;
 
  712  LogicalResult matchAndRewrite(CopyOp copyOp,
 
  713                                PatternRewriter &rewriter)
 const override {
 
  714    if (copyOp.getInputs() != copyOp.getOutputs())
 
  716    if (copyOp.hasPureBufferSemantics())
 
  719      rewriter.
replaceOp(copyOp, copyOp.getInputs());
 
  729  results.
add<EraseSelfCopy>(context);
 
  742template <
typename TensorReshapeOp>
 
  744  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
 
  745  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
 
  746                                PatternRewriter &rewriter)
 const override {
 
  747    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
 
  751    Location loc = oldFill.getLoc();
 
  752    TensorReshapeOp newInit;
 
  753    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
 
  755      newInit = TensorReshapeOp::create(
 
  756          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
 
  757          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
 
  758          reshapeOp.getStaticOutputShape());
 
  760      newInit = TensorReshapeOp::create(
 
  761          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
 
  762          reshapeOp.getReassociation());
 
  775  LogicalResult matchAndRewrite(tensor::PadOp padOp,
 
  776                                PatternRewriter &rewriter)
 const override {
 
  777    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
 
  783    Value padValue = padOp.getConstantPaddingValue();
 
  784    if (!padValue || fillOp.value() != padValue)
 
  790          padOp, 
"failed to reify tensor.pad op result shape");
 
  793        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
 
  794                                padOp.getResultType().getElementType());
 
  796        FillOp::create(rewriter, fillOp.getLoc(), 
ValueRange{padValue},
 
  799    if (
replacement.getType() != padOp.getResultType()) {
 
  800      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
 
  811struct FoldInsertPadIntoFill : 
public OpRewritePattern<tensor::InsertSliceOp> {
 
  814  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
 
  815                                PatternRewriter &rewriter)
 const override {
 
  816    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
 
  820    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
 
  825    Value firstDest = insertOp.getDest();
 
  826    while (
auto prevOp = firstDest.
getDefiningOp<tensor::InsertSliceOp>()) {
 
  827      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
 
  832      bool disjoint = 
false;
 
  833      for (
int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
 
  836        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
 
  837            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
 
  838            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
 
  842        int64_t prevStart = prevOp.getStaticOffset(i);
 
  843        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
 
  844                                          prevOp.getStaticStride(i);
 
  845        int64_t nextStart = insertOp.getStaticOffset(i);
 
  846        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
 
  847                                          insertOp.getStaticStride(i);
 
  848        if (prevEnd < nextStart || nextEnd < prevStart) {
 
  856      firstDest = prevOp.getDest();
 
  867    Value padValue = srcPadOp.getConstantPaddingValue();
 
  868    if (!padValue || dstFillOp.value() != padValue)
 
  871    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
 
  872    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
 
  874    Location loc = insertOp.getLoc();
 
  877    AffineExpr sym0, sym1;
 
  883    SmallVector<OpFoldResult, 4> newOffsets;
 
  884    for (
const auto &p : llvm::zip(lowPads, oldOffsets)) {
 
  886          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
 
  889    RankedTensorType srcPadType = srcPadOp.getSourceType();
 
  890    SmallVector<OpFoldResult, 4> newSizes;
 
  891    for (
int i = 0, e = srcPadType.getRank(); i < e; ++i) {
 
  892      if (srcPadType.isDynamicDim(i)) {
 
  894            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
 
  897        newSizes.push_back(rewriter.
getIndexAttr(srcPadType.getDimSize(i)));
 
  902        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
 
  903        newSizes, insertOp.getMixedStrides());
 
  909struct FoldFillWithTensorExtract : 
public OpRewritePattern<tensor::ExtractOp> {
 
  911  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
 
  913  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
 
  914                                PatternRewriter &rewriter)
 const override {
 
  917    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
 
  922    Value extractedScalar = fillOp.getInputs()[0];
 
  925    rewriter.
replaceOp(extractOp, extractedScalar);
 
  933static FailureOr<FillOp> foldFillPackIntoFillOp(
RewriterBase &rewriter,
 
  934                                                linalg::PackOp packOp) {
 
  935  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
 
  939  if (
auto paddingValue = packOp.getPaddingValue())
 
  943  Value packOpDest = packOp.getDest();
 
  947  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
 
  954  FoldFillWithPack(MLIRContext *context)
 
  955      : OpRewritePattern<linalg::PackOp>(context) {}
 
  957  LogicalResult matchAndRewrite(linalg::PackOp packOp,
 
  958                                PatternRewriter &rewriter)
 const override {
 
  959    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
 
  962    rewriter.
replaceOp(packOp, fillOp.value().result());
 
  969  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
 
  971  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
 
  972                                PatternRewriter &rewriter)
 const override {
 
  973    if (
auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
 
  976                                          copyOp.getOutputs());
 
  979    if (
auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
 
  981                                                  fillOp.getOutputs());
 
  990  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
 
  992  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
 
  993                                PatternRewriter &rewriter)
 const override {
 
  994    if (
auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
 
  996          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
 
  997          transposeOp.getDpsInitOperand(0)->get());
 
 1009  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
 
 1010                                PatternRewriter &rewriter)
 const override {
 
 1011    auto concatOperands = concatOp.getInputs();
 
 1012    if (concatOperands.empty()) {
 
 1016    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
 
 1021    OpFoldResult firstFillVal =
 
 1024    SmallVector<Value> allOuts;
 
 1025    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
 
 1027    auto isDefinedByCompatibleFillOp = [&](Value v) -> 
bool {
 
 1028      auto fillOp = v.getDefiningOp<linalg::FillOp>();
 
 1033      OpFoldResult fillVal =
 
 1035      if (fillVal != firstFillVal)
 
 1038      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
 
 1041    if (!llvm::all_of(concatOperands.drop_front(),
 
 1042                      isDefinedByCompatibleFillOp)) {
 
 1044          concatOp, 
"not all operands are defined by a compatible fill op");
 
 1047    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
 
 1048                                                concatOp.getDim(), allOuts);
 
 1050        concatOp, firstFillOp.getDpsInputOperand(0)->
get(), outsConcat);
 
 1059  results.
add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
 
 1060              FoldFillWithPack, FoldFillWithPad,
 
 1061              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
 
 1062              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
 
 1063              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
 
 1076  for (
ValueRange container : {inputs, outputs}) {
 
 1077    for (
Value v : container) {
 
 1078      Type t = v.getType();
 
 1079      blockArgTypes.push_back(
 
 1081      blockArgLocs.push_back(v.getLoc());
 
 1087      builder.
createBlock(®ion, region.
end(), blockArgTypes, blockArgLocs);
 
 
 1091void GenericOp::getAsmBlockArgumentNames(
Region ®ion,
 
 1093  for (Value v : getRegionInputArgs())
 
 1095  for (Value v : getRegionOutputArgs())
 
 1096    setNameFn(v, 
"out");
 
 1099void GenericOp::build(
 
 1100    OpBuilder &builder, OperationState &
result, 
TypeRange resultTensorTypes,
 
 1102    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
 
 1104    ArrayRef<NamedAttribute> attributes) {
 
 1105  build(builder, 
result, resultTensorTypes, inputs, outputs, indexingMaps,
 
 1106        iteratorTypes, doc, libraryCall);
 
 1107  result.addAttributes(attributes);
 
 1110                       inputs, outputs, bodyBuild);
 
 1113void GenericOp::build(
 
 1114    OpBuilder &builder, OperationState &
result, 
TypeRange resultTensorTypes,
 
 1116    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
 
 1117    StringRef libraryCall,
 
 1119    ArrayRef<NamedAttribute> attributes) {
 
 1120  build(builder, 
result, resultTensorTypes, inputs, outputs,
 
 1124            [&](utils::IteratorType iter) -> mlir::Attribute {
 
 1125              return IteratorTypeAttr::get(builder.getContext(), iter);
 
 1128        libraryCall.empty() ? StringAttr() : builder.
getStringAttr(libraryCall),
 
 1129        bodyBuild, attributes);
 
 1132void GenericOp::build(
 
 1134    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
 
 1135    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
 
 1136    StringRef libraryCall,
 
 1138    ArrayRef<NamedAttribute> attributes) {
 
 1140        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
 
 1143void GenericOp::build(
 
 1145    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
 
 1146    ArrayRef<utils::IteratorType> iteratorTypes,
 
 1148    ArrayRef<NamedAttribute> attributes) {
 
 1149  build(builder, 
result, inputs, outputs, indexingMaps, iteratorTypes,
 
 1151        "", bodyBuild, attributes);
 
 1154void GenericOp::build(
 
 1155    OpBuilder &builder, OperationState &
result, 
TypeRange resultTensorTypes,
 
 1157    ArrayRef<utils::IteratorType> iteratorTypes,
 
 1159    ArrayRef<NamedAttribute> attributes) {
 
 1160  build(builder, 
result, resultTensorTypes, inputs, outputs, indexingMaps,
 
 1163        "", bodyBuild, attributes);
 
 1166void GenericOp::print(OpAsmPrinter &p) {
 
 1170  auto genericAttrNames = linalgTraitAttrNames();
 
 1172  llvm::StringSet<> genericAttrNamesSet;
 
 1173  genericAttrNamesSet.insert_range(genericAttrNames);
 
 1174  SmallVector<NamedAttribute, 8> genericAttrs;
 
 1175  for (
auto attr : (*this)->getAttrs()) {
 
 1176    if (attr.getName() == getIteratorTypesAttrName()) {
 
 1177      auto iteratorTypes =
 
 1178          llvm::cast<ArrayAttr>(attr.getValue())
 
 1179              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
 
 1184      SmallVector<Attribute> iteratorTypeNames =
 
 1185          llvm::to_vector(llvm::map_range(
 
 1186              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
 
 1187                return StringAttr::get(
getContext(), stringifyIteratorType(t));
 
 1190      genericAttrs.emplace_back(
 
 1191          getIteratorTypesAttrName(),
 
 1192          ArrayAttr::get(
getContext(), iteratorTypeNames));
 
 1193    } 
else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
 
 1194      genericAttrs.push_back(attr);
 
 1197  if (!genericAttrs.empty()) {
 
 1198    auto genericDictAttr = DictionaryAttr::get(
getContext(), genericAttrs);
 
 1199    p << genericDictAttr;
 
 1205  genericAttrNames.push_back(
"operandSegmentSizes");
 
 1206  genericAttrNamesSet.insert(genericAttrNames.back());
 
 1208  bool hasExtraAttrs = 
false;
 
 1209  for (NamedAttribute n : (*this)->getAttrs()) {
 
 1210    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
 
 1213  if (hasExtraAttrs) {
 
 1220  if (!getRegion().empty()) {
 
 1229ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 1230  DictionaryAttr dictAttr;
 
 1238  result.attributes.assign(dictAttr.getValue().begin(),
 
 1239                           dictAttr.getValue().end());
 
 1245  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
 
 1246      result.attributes.get(getIteratorTypesAttrName(
result.name)));
 
 1247  if (!iteratorTypes) {
 
 1248    return parser.
emitError(attributeLocation)
 
 1249           << 
"expected " << getIteratorTypesAttrName(
result.name)
 
 1250           << 
" array attribute";
 
 1253  SmallVector<Attribute> iteratorTypeAttrs;
 
 1255  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
 
 1256    auto maybeIteratorType = utils::symbolizeIteratorType(s);
 
 1257    if (!maybeIteratorType.has_value())
 
 1259             << 
"unexpected iterator_type (" << s << 
")";
 
 1261    iteratorTypeAttrs.push_back(
 
 1262        IteratorTypeAttr::get(parser.
getContext(), maybeIteratorType.value()));
 
 1264  result.attributes.set(getIteratorTypesAttrName(
result.name),
 
 1268  SmallVector<Type, 1> inputTypes, outputTypes;
 
 1278  std::unique_ptr<Region> region = std::make_unique<Region>();
 
 1281  result.addRegion(std::move(region));
 
 1287  SmallVector<Type, 1> outputTensorsTypes;
 
 1290  result.addTypes(outputTensorsTypes);
 
 1298    LinalgOp linalgOp) {
 
 1299  for (
auto [
index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
 
 1300    if (!llvm::isa<MemRefType>(operand.
getType()))
 
 1302    effects.emplace_back(
 
 1307  for (
OpOperand &operand : linalgOp.getDpsInitsMutable()) {
 
 1308    if (!llvm::isa<MemRefType>(operand.get().
getType()))
 
 1310    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
 
 
 1321void GenericOp::getEffects(
 
 1322    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 1331  if (!linalgOp.hasPureTensorSemantics())
 
 
 1341LogicalResult GenericOp::verify() { 
return success(); }
 
 1351template <
typename OpTy>
 
 1352struct EraseIdentityLinalgOp : 
public OpRewritePattern<OpTy> {
 
 1353  using OpRewritePattern<OpTy>::OpRewritePattern;
 
 1355  LogicalResult matchAndRewrite(OpTy linalgOp,
 
 1356                                PatternRewriter &rewriter)
 const override {
 
 1358    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
 
 1363    Block &body = linalgOp->getRegion(0).front();
 
 1364    if (!llvm::hasSingleElement(body))
 
 1366    auto yieldOp = dyn_cast<linalg::YieldOp>(body.
getTerminator());
 
 1371    if (linalgOp.hasPureBufferSemantics()) {
 
 1372      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
 
 1373          linalgOp.getDpsInputOperand(0)->get() !=
 
 1374              linalgOp.getDpsInitOperand(0)->get()) {
 
 1376            linalgOp, 
"expected single input and output to be the same value");
 
 1379      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
 
 1380      if (!yieldArg || yieldArg.getOwner() != &body) {
 
 1382                                           "cannot fold fill-like op");
 
 1389    if (!linalgOp.hasPureTensorSemantics()) {
 
 1391          linalgOp, 
"mixed semantics is not supported yet");
 
 1396    SmallVector<Value> returnedArgs;
 
 1397    for (
const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
 
 1398      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
 
 1399      if (!yieldArg || yieldArg.getOwner() != &body)
 
 1401      unsigned argumentNumber = yieldArg.getArgNumber();
 
 1402      Value returnedArg = linalgOp->getOperand(argumentNumber);
 
 1403      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
 
 1406      Type returnType = returnedArg.
getType();
 
 1407      if (returnType != resultType) {
 
 1412          returnedArg = sparse_tensor::ConvertOp::create(
 
 1413              rewriter, linalgOp.getLoc(), resultType, returnedArg);
 
 1415          if (!tensor::CastOp::areCastCompatible(returnedArg.
getType(),
 
 1418          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
 
 1419                                               resultType, returnedArg);
 
 1422      returnedArgs.push_back(returnedArg);
 
 1425    if (returnedArgs.size() != linalgOp->getNumResults())
 
 1427    rewriter.
replaceOp(linalgOp, returnedArgs);
 
 /// Adds GenericOp's canonicalization patterns to `results`: the
 /// EraseIdentityLinalgOp pattern instantiated for GenericOp, constructed
 /// with `context`.
 1434void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 1435                                            MLIRContext *context) {
 
 1436  results.
add<EraseIdentityLinalgOp<GenericOp>>(context);
 
 1439LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 
 1458  for (
Type outputType : outputTypes) {
 
 1459    if (llvm::isa<RankedTensorType>(outputType))
 
 1460      result.addTypes(outputType);
 
 1464  if (parseAttrsFn && failed(parseAttrsFn(parser, 
result.attributes)))
 
 
 1473void MapOp::getAsmBlockArgumentNames(Region ®ion,
 
 1475  for (Value v : getRegionInputArgs())
 
 1477  for (Value v : getRegionOutputArgs())
 
 1478    setNameFn(v, 
"init");
 
 /// Suggests the name "mapped" for the op's first result by calling
 /// `setNameFn` on it. Does nothing when the op has no results.
 1481void MapOp::getAsmResultNames(
function_ref<
void(Value, StringRef)> setNameFn) {
 
 1482  if (!getResults().empty())
 
 1483    setNameFn(getResults().front(), 
"mapped");
 
 1489    ArrayRef<NamedAttribute> attributes) {
 
 1491  result.addAttributes(attributes);
 
 1494  Type initType = init.
getType();
 
 1495  if (llvm::isa<RankedTensorType>(initType))
 
 1496    result.addTypes(initType);
 
 1500                       inputs, {init}, bodyBuild);
 
 1507                                 bool initFirst = 
false, 
bool mapInit = 
true) {
 
 1511  b.setInsertionPointToStart(&block);
 
 1512  for (
auto &operand : operands) {
 
 1514        llvm::cast<ShapedType>(operand.
getType()).getElementType(),
 
 1522      payloadOpOperands.push_back(block.
getArguments().back());
 
 1523    for (
const auto &arg : block.
getArguments().drop_back())
 
 1524      payloadOpOperands.push_back(arg);
 
 1533      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
 
 
 1539ParseResult MapOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 1540  std::optional<OperationName> payloadOpName;
 
 1541  NamedAttrList payloadOpAttrs;
 
 1544    if (
failed(operationName))
 
 1548    payloadOpName = operationName.value();
 
 1556  if (payloadOpName.has_value()) {
 
 1557    if (!
result.operands.empty())
 
 1559                           payloadOpAttrs, ArrayRef(
result.operands), 
false,
 
 1564    SmallVector<OpAsmParser::Argument> regionArgs;
 
 1569    Region *body = 
result.addRegion();
 
 1577                            bool mapInit = 
true) {
 
 1579  if (initFirst && !mapInit)
 
 1603    for (
const auto &[operand, bbArg] :
 
 1605      if (bbArg != operand)
 
 1609    for (
const auto &[operand, bbArg] :
 
 1612      if (bbArg != operand)
 
 1619  return yieldOp.getNumOperands() == 1 &&
 
 1620         yieldOp.getOperand(0).getDefiningOp() &&
 
 1621         yieldOp.getOperand(0).getDefiningOp() == &payload;
 
 
 1626  std::string attrToElide;
 
 1628  for (
const auto &attr : payloadOp->
getAttrs()) {
 
 1630        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
 
 1631    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
 
 1632      attrToElide = attr.getName().str();
 
 1633      elidedAttrs.push_back(attrToElide);
 
 
 1641void MapOp::print(OpAsmPrinter &p) {
 
 1642  Block *mapper = getBody();
 
 1652  if (!useShortForm) {
 
 1658                          [&](
auto arg) { p.printRegionArgument(arg); });
 
 1666LogicalResult MapOp::verify() {
 
 1667  auto *bodyBlock = getBody();
 
 1668  auto blockArgs = bodyBlock->getArguments();
 
 1672  if (getInputs().size() + 1 != blockArgs.size())
 
 1673    return emitOpError() << 
"expects number of operands to match the arity of " 
 1675                         << getInputs().size() + 1 << 
" and " 
 1676                         << blockArgs.size();
 
 1679  for (
const auto &[bbArgType, inputArg] :
 
 1680       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
 
 1681    auto inputElemType =
 
 1682        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
 
 1683    if (bbArgType != inputElemType) {
 
 1684      return emitOpError() << 
"expected element type of input " << inputElemType
 
 1685                           << 
" to match bbArg type " << bbArgType;
 
 1690  auto outputShape = getInit().getType().getShape();
 
 1691  for (Type inputArgType : 
TypeRange{getInputs()}) {
 
 1692    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
 
 1693    if (inputElemShape != outputShape) {
 
 1694      return emitOpError() << 
"expected shape of input (" << inputElemShape
 
 1695                           << 
") to match shape of output (" << outputShape
 
 1703SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
 
 1704  int64_t rank = getInit().getType().getRank();
 
 1705  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
 
 1710  int64_t rank = getInit().getType().getRank();
 
 1711  int64_t numIndexingMaps = getOperands().size();
 
 1716void MapOp::getEffects(
 
 1717    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 1730void ReduceOp::getAsmBlockArgumentNames(Region ®ion,
 
 1732  for (Value v : getRegionInputArgs())
 
 1734  for (Value v : getRegionOutputArgs())
 
 1735    setNameFn(v, 
"init");
 
 1738void ReduceOp::getAsmResultNames(
 
 1740  if (!getResults().empty())
 
 1741    setNameFn(getResults().front(), 
"reduced");
 
 1744void ReduceOp::build(
 
 1746    ValueRange inits, ArrayRef<int64_t> dimensions,
 
 1748    ArrayRef<NamedAttribute> attributes) {
 
 1750  result.addAttributes(attributes);
 
 1753  for (Value init : inits) {
 
 1754    Type initType = init.
getType();
 
 1755    if (llvm::isa<RankedTensorType>(initType))
 
 1756      result.addTypes(initType);
 
 1761                       inputs, inits, bodyBuild);
 
 1764SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
 
 1766      llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
 
 1767  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
 
 1768                                                 utils::IteratorType::parallel);
 
 1769  for (int64_t reductionDim : getDimensions())
 
 1770    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
 
 1771  return iteratorTypes;
 
 1776      llvm::cast<ShapedType>(getInputs()[0].
getType()).getRank();
 
 1777  SmallVector<AffineMap> affineMaps(
 
 1780  AffineMap resultMap =
 
 1783  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
 
 1784    affineMaps.push_back(resultMap);
 
 1785  return Builder(
getContext()).getAffineMapArrayAttr(affineMaps);
 
 1788void ReduceOp::getEffects(
 
 1789    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 1800                                          StringRef attributeName) {
 
 
 1808ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 1809  std::optional<OperationName> payloadOpName;
 
 1810  NamedAttrList payloadOpAttrs;
 
 1813    if (
failed(operationName))
 
 1817    payloadOpName = operationName.value();
 
 1823          parser, 
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
 
 1828  if (payloadOpName.has_value()) {
 
 1830                         ArrayRef(
result.operands), 
true);
 
 1832    SmallVector<OpAsmParser::Argument> regionArgs;
 
 1838    Region *body = 
result.addRegion();
 
 1848  p << 
' ' << attributeName << 
" = [" << attributeValue << 
"] ";
 
 
 1851void ReduceOp::print(OpAsmPrinter &p) {
 
 1852  Block *mapper = getBody();
 
 1861  if (!useShortForm) {
 
 1867                          [&](
auto arg) { p.printRegionArgument(arg); });
 
 1875LogicalResult ReduceOp::verify() {
 
 1876  ArrayRef<int64_t> dimensionsRef = getDimensions();
 
 1878  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
 
 1881      return emitOpError() << 
"expects all inputs to have the same shapes. " 
 1882                              "Shape at input-index " 
 1884                           << 
" is not equal to the shape at input-index 0.";
 
 1887  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
 
 1890      return emitOpError() << 
"expects all outputs to have the same shapes. " 
 1891                              "Shape at output-index " 
 1893                           << 
" is not equal to the shape at output-index 0.";
 
 1896  auto inputType = llvm::cast<ShapedType>(getInputs()[0].
getType());
 
 1897  auto initType = llvm::cast<ShapedType>(getInits()[0].
getType());
 
 1900  for (int64_t dimension : dimensionsRef) {
 
 1901    if (dimension < 0 || dimension >= inputType.getRank()) {
 
 1903             << 
"dimensions for reduction should be in the range [0, " 
 1904             << inputType.getRank() - 1 << 
"].";
 
 1906    dimensionsToReduce.insert(dimension);
 
 1909  auto inputDims = inputType.getShape();
 
 1910  auto initDims = initType.getShape();
 
 1913  SmallVector<int64_t> reducedInputDims;
 
 1914  for (
const auto &en : llvm::enumerate(inputDims)) {
 
 1915    if (!dimensionsToReduce.count(en.index()))
 
 1916      reducedInputDims.push_back(en.value());
 
 1919  if (reducedInputDims.size() != 
static_cast<size_t>(initType.getRank())) {
 
 1920    return emitOpError() << 
"number of dimensions after reduction " 
 1921                         << reducedInputDims.size()
 
 1922                         << 
" doesn't match the init rank " 
 1923                         << initType.getRank();
 
 1926  if (reducedInputDims != initDims)
 
 1927    return emitOpError() << 
"init dimensions [" << initDims
 
 1928                         << 
"] doesn't match input dimensions after reduction [" 
 1929                         << reducedInputDims << 
"]";
 
 1931  Block *block = getBody();
 
 1934           << 
"mismatching number of operands and block arguments";
 
 1937  for (
auto [input, bbArg] : llvm::zip(getInputs(), block->
getArguments())) {
 
 1938    Type inputElementType =
 
 1939        llvm::cast<ShapedType>(input.getType()).getElementType();
 
 1940    if (inputElementType != bbArg.getType())
 
 1942             << 
"input element type " << inputElementType
 
 1943             << 
" does not match corresponding block argument type " 
 1948  for (
auto [output, bbArg] : llvm::zip(
 
 1949           getDpsInits(), block->
getArguments().take_back(getNumDpsInits()))) {
 
 1950    auto outputElementType =
 
 1951        llvm::cast<ShapedType>(output.getType()).getElementType();
 
 1952    if (outputElementType != bbArg.getType())
 
 1954             << 
"output element type " << outputElementType
 
 1955             << 
" does not match corresponding block argument type " 
 1971                         linalg::YieldOp::create(
b, loc, args[0]);
 
 
 1975void TransposeOp::build(::mlir::OpBuilder &builder,
 
 1976                        ::mlir::OperationState &
result, Value input, Value init,
 
 1978                        ArrayRef<NamedAttribute> attributes) {
 
 1979  result.addOperands(input);
 
 1980  result.addOperands(init);
 
 1981  result.addAttribute(getPermutationAttrName(
result.name), permutation);
 
 1982  result.addAttributes(attributes);
 
 1985  Type initType = init.
getType();
 
 1986  if (llvm::isa<RankedTensorType>(initType))
 
 1987    result.addTypes(initType);
 
 1993void TransposeOp::build(::mlir::OpBuilder &builder,
 
 1994                        ::mlir::OperationState &
result, Value input, Value init,
 
 1995                        ArrayRef<int64_t> permutation,
 
 1996                        ArrayRef<NamedAttribute> attributes) {
 
 2001ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 2003          parser, 
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
 
 2015void TransposeOp::getAsmResultNames(
 
 2017  if (!getResults().empty())
 
 2018    setNameFn(getResults().front(), 
"transposed");
 
 2021void TransposeOp::print(OpAsmPrinter &p) {
 
 2027LogicalResult TransposeOp::verify() {
 
 2028  ArrayRef<int64_t> permutationRef = getPermutation();
 
 2033  auto inputType = getInput().getType();
 
 2034  auto initType = getInit().getType();
 
 2036  int64_t rank = inputType.getRank();
 
 2038  if (rank != initType.getRank())
 
 2040                         << 
" does not match init rank " << initType.getRank();
 
 2042  if (rank != 
static_cast<int64_t
>(permutationRef.size()))
 
 2043    return emitOpError() << 
"size of permutation " << permutationRef.size()
 
 2044                         << 
" does not match the argument rank " << rank;
 
 2046  auto inputDims = inputType.getShape();
 
 2047  auto initDims = initType.getShape();
 
 2049  for (int64_t i = 0; i < rank; ++i) {
 
 2050    int64_t inputDim = inputDims[permutationRef[i]];
 
 2051    int64_t initDim = initDims[i];
 
 2053    if (inputDim != initDim) {
 
 2054      return emitOpError() << 
"dim(result, " << i << 
") = " << initDim
 
 2055                           << 
" doesn't match dim(input, permutation[" << i
 
 2056                           << 
"]) = " << inputDim;
 
 2063SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
 
 2064  int64_t rank = getInit().getType().getRank();
 
 2065  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
 
 2068ArrayAttr TransposeOp::getIndexingMaps() {
 
 2070  int64_t rank = getInit().getType().getRank();
 
 2073           llvm::to_vector_of<unsigned>(getPermutation()), 
getContext())),
 
 2077void TransposeOp::getEffects(
 
 2078    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 2087LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
 
 2088                                SmallVectorImpl<OpFoldResult> &
result) {
 
 2090  if (!isa<TensorType>(getInput().
getType()))
 
 2094  if (getPermutation().size() == 0) {
 
 2095    result.push_back(getInput());
 
 2100    result.push_back(getInput());
 
 2113    auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
 
 2114    if (!defTransposeOp)
 
 2119    foldedPerms.reserve(perms.size());
 
 2121      foldedPerms.push_back(defPerms[perm]);
 
 2124        transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
 
 
 
 2138    Value input = transposeOp.getInput();
 
 2139    BroadcastOp broadcastOp = input.
getDefiningOp<BroadcastOp>();
 
 2150    unsigned dimensionSize = dimensions.size();
 
 2151    for (
unsigned i = 0; i < dimensionSize; ++i)
 
 2152      resultDimensions.push_back(invertPerm[dimensions[i]]);
 
 2155    Value broadcastInput = broadcastOp.getInput();
 
 2156    Location loc = transposeOp.getLoc();
 
 2159    auto broadcastInputTy =
 
 2160        mlir::cast<RankedTensorType>(broadcastInput.
getType());
 
 2161    unsigned inputRank = broadcastInputTy.getRank();
 
 2162    for (
unsigned i = 0; i < inputRank; ++i) {
 
 2163      if (broadcastInputTy.isDynamicDim(i)) {
 
 2164        dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
 
 2167        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
 
 2168                                        broadcastInputTy.getDimSize(i)));
 
 2173    Value transposeInit = tensor::EmptyOp::create(
 
 2174        rewriter, transposeOp.getLoc(), transposeResultShapes,
 
 2175        broadcastInputTy.getElementType());
 
 2178    Value transposeResult =
 
 2179        TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
 
 2180                            transposeInit, resultPerms)
 
 2183        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
 
 
 
 2188void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 2189                                              MLIRContext *context) {
 
 2190  results.
add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
 
 2197void BroadcastOp::build(::mlir::OpBuilder &builder,
 
 2198                        ::mlir::OperationState &
result, Value input, Value init,
 
 2200                        ArrayRef<NamedAttribute> attributes) {
 
 2201  result.addOperands(input);
 
 2202  result.addOperands(init);
 
 2203  result.addAttribute(getDimensionsAttrName(
result.name), dimensions);
 
 2204  result.addAttributes(attributes);
 
 2207  Type initType = init.
getType();
 
 2208  if (llvm::isa<RankedTensorType>(initType))
 
 2209    result.addTypes(initType);
 
 2215void BroadcastOp::build(::mlir::OpBuilder &builder,
 
 2216                        ::mlir::OperationState &
result, Value input, Value init,
 
 2217                        ArrayRef<int64_t> dimensions,
 
 2218                        ArrayRef<NamedAttribute> attributes) {
 
 2223ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 2225          parser, 
result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
 
 2237void BroadcastOp::getAsmResultNames(
 
 2239  if (!getResults().empty())
 
 2240    setNameFn(getResults().front(), 
"broadcasted");
 
 2243void BroadcastOp::print(OpAsmPrinter &p) {
 
 2249LogicalResult BroadcastOp::verify() {
 
 2250  ArrayRef<int64_t> dimensionsRef = getDimensions();
 
 2252  auto inputType = getInput().getType();
 
 2253  auto initType = getInit().getType();
 
 2255  int64_t inputRank = inputType.getRank();
 
 2256  int64_t initRank = initType.getRank();
 
 2258  auto inputShape = inputType.getShape();
 
 2259  auto initShape = initType.getShape();
 
 2261  if ((
size_t)inputRank + dimensionsRef.size() != (
size_t)initRank)
 
 2262    return emitOpError() << 
"input rank plus added dimensions does not " 
 2263                            "match init rank. input rank: " 
 2265                         << 
", dimensions size: " << dimensionsRef.size()
 
 2266                         << 
", init rank: " << initRank;
 
 2268  for (
const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
 
 2269    if (dim < 0 || dim >= initRank)
 
 2271                           << 
" is out of range. expected range: [0, " 
 2272                           << initRank - 1 << 
"], got: " << dim;
 
 2276  SmallVector<int64_t> dimMap;
 
 2277  for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
 
 2278    if (!llvm::is_contained(dimensionsRef, dim))
 
 2279      dimMap.push_back(dim);
 
 2282  for (
const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
 
 2285    if (inputShape[inputDimIdx] != initShape[initDimIdx])
 
 2286      return emitOpError() << 
"input dim " << inputDimIdx
 
 2287                           << 
" should match init dim " << initDimIdx
 
 2288                           << 
". input: " << inputShape[inputDimIdx]
 
 2289                           << 
", init: " << initShape[initDimIdx];
 
 2295SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
 
 2296  int64_t rank = getInit().getType().getRank();
 
 2297  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
 
 2300ArrayAttr BroadcastOp::getIndexingMaps() {
 
 2302  int64_t rank = getInit().getType().getRank();
 
 2308void BroadcastOp::getEffects(
 
 2309    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 2324    auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
 
 2325    if (!defBroadcastOp)
 
 2330    Value init = broadcastOp.getInit();
 
 2334    for (
auto dim : llvm::seq<int64_t>(0, initRank)) {
 
 2335      if (!llvm::is_contained(dimensions, dim))
 
 2336        dimMap.push_back(dim);
 
 2338    for (
auto dim : defDimensions)
 
 2339      foldedDims.push_back(dimMap[dim]);
 
 2341    llvm::sort(foldedDims);
 
 2343        broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
 
 
 
 2348void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 2349                                              MLIRContext *context) {
 
 2350  results.
add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
 
 2357void linalg::YieldOp::print(OpAsmPrinter &p) {
 
 2358  if (getNumOperands() > 0)
 
 2359    p << 
' ' << getOperands();
 
 2361  if (getNumOperands() > 0)
 
 2362    p << 
" : " << getOperandTypes();
 
 2365ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 2366  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
 
 2367  SmallVector<Type, 2> types;
 
 2377static LogicalResult 
verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
 
 2378  if (op.getNumOperands() != linalgOp.getNumDpsInits())
 
 2379    return op.emitOpError(
"expected number of yield values (")
 
 2380           << op.getNumOperands()
 
 2381           << 
") to match the number of inits / outs operands of the enclosing " 
 2382           << 
"LinalgOp (" << linalgOp.getNumDpsInits() << 
")";
 
 2384  for (
OpOperand &opOperand : op->getOpOperands()) {
 
 2386        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
 
 2388    if (isa<MemRefType, RankedTensorType>(elementType))
 
 2390    if (opOperand.get().getType() != elementType)
 
 2391      return op.emitOpError(
"type of yield operand ")
 
 2392             << (opOperand.getOperandNumber() + 1) << 
" (" 
 2393             << opOperand.get().getType() << 
") doesn't match " 
 2394             << 
"the element type of the enclosing linalg.generic op (" 
 2395             << elementType << 
")";
 
 
 2400LogicalResult linalg::YieldOp::verify() {
 
 2401  auto *parentOp = (*this)->getParentOp();
 
 2402  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
 
 2403    return emitOpError(
"expected single non-empty parent region");
 
 2405  if (
auto linalgOp = dyn_cast<LinalgOp>(parentOp))
 
 2408  return emitOpError(
"expected parent op with LinalgOp interface");
 
 2415LogicalResult IndexOp::verify() {
 
 2416  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
 
 2418    return emitOpError(
"expected parent op with LinalgOp interface");
 
 2419  if (linalgOp.getNumLoops() <= getDim())
 
 2421           << getDim() << 
") to be lower than the number of loops (" 
 2422           << linalgOp.getNumLoops() << 
") of the enclosing LinalgOp";
 
 2426OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
 
 2427  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
 
 2432    return OpFoldResult{};
 
 2435  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
 
 2436  uint64_t dim = getDim();
 
 2437  assert(dim < loopBounds.size() && 
"Dim is out of bounds");
 
 2438  if (loopBounds[dim] == 1)
 
 2439    return IntegerAttr::get(IndexType::get(
getContext()), 0);
 
 2441  return OpFoldResult{};
 
 2446#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" 
 2448#define GET_OP_CLASSES 
 2449#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" 
 2451#define GET_OP_CLASSES 
 2452#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" 
 2453#define GET_OP_CLASSES 
 2454#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc" 
 2471  for (
unsigned i = 0; i < num; ++i)
 
 
 2478  auto rangeA = llvm::make_range(a.begin(), a.end());
 
 2479  auto rangeB = llvm::make_range(
b.begin(), 
b.end());
 
 2480  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
 
 2481  return llvm::to_vector<4>(concatRanges);
 
 
 2485  if (
auto memref = llvm::dyn_cast<MemRefType>(t)) {
 
 2487    for (
auto size : 
memref.getShape())
 
 2494    if (
auto as = 
memref.getMemorySpace()) {
 
 2495      if (
auto attr = llvm::dyn_cast<IntegerAttr>(as))
 
 2496        ss << 
"as" << attr.getInt();
 
 2502  if (
auto vec = llvm::dyn_cast<VectorType>(t)) {
 
 2505        vec.getShape(), [&](
int64_t i) { ss << i; }, [&]() { ss << 
"x"; });
 
 
 2518  assert(isa<LinalgOp>(op));
 
 2520  std::string fun = 
"";
 
 2522    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
 
 2523      fun = stringifyEnum(ufa.getValue()).str() + 
"_";
 
 2524    } 
else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
 
 2525      fun = stringifyEnum(bfa.getValue()).str() + 
"_";
 
 2529  llvm::replace(name, 
'.', 
'_');
 
 2530  llvm::raw_string_ostream ss(name);
 
 2534      return std::string();
 
 
 2549  LogicalResult matchAndRewrite(LinalgOp op,
 
 2551    for (
OpOperand &opOperand : op->getOpOperands()) {
 
 2555      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
 
 2558      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
 
 2569struct FoldTensorCastConsumerOp : 
public OpRewritePattern<tensor::CastOp> {
 
 2570  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 
 2572  LogicalResult matchAndRewrite(tensor::CastOp castOp,
 
 2573                                PatternRewriter &rewriter)
 const override {
 
 2577    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
 
 2584    if (castOp->getBlock() != linalgOp->getBlock())
 
 2587    OpBuilder::InsertionGuard guard(rewriter);
 
 2590    Location loc = linalgOp.getLoc();
 
 2591    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
 
 2594        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
 
 2600    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
 
 2602        tensor::CastOp::create(rewriter, loc, resultType, outOperand->
get());
 
 2603    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
 
 2604    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
 
 2605                                      linalgOp.getDpsInits().end());
 
 2606    outputOperands[resultNumber] = newOperand;
 
 2607    newOperands.append(outputOperands.begin(), outputOperands.end());
 
 2609    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
 
 2610                                  linalgOp->result_type_end());
 
 2611    resultTypes[resultNumber] = resultType;
 
 2612    Operation *newOp = 
clone(rewriter, linalgOp, resultTypes, newOperands);
 
 2615    Value castBack = tensor::CastOp::create(
 
 2619    results[resultNumber] = castBack;
 
 2628static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
 
 2629                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
 
 2630  for (OpOperand &opOperand : operands) {
 
 2631    if (linalgOp.isScalar(&opOperand))
 
 2633    Value src = opOperand.get();
 
 2634    auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
 
 2635    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
 
 2641    ArrayRef<int64_t> sourceShape = sourceType.getShape();
 
 2643      if (
auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
 
 2644        Value castSource = castOp.getSource();
 
 2645        auto castSourceType =
 
 2646            llvm::dyn_cast<RankedTensorType>(castSource.
getType());
 
 2647        if (castSourceType && castSourceType.hasStaticShape())
 
 2648          sourceShape = castSourceType.getShape();
 
 2654    for (
unsigned i = 0; i < sourceShape.size(); i++) {
 
 2655      if (sourceType.isDynamicDim(i))
 
 2657      if (
auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
 
 2658        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
 
 2668static void createNewOperandWithStaticSizes(
 
 2669    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
 
 2670    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
 
 2671    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
 
 2672    bool &changeNeeded) {
 
 2673  Value src = opOperand->
get();
 
 2674  newOperands.push_back(src);
 
 2675  if (linalgOp.isScalar(opOperand))
 
 2677  auto sourceType = llvm::cast<RankedTensorType>(src.
getType());
 
 2678  Type resultType = sourceType;
 
 2679  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
 
 2680    resultTypes.push_back(resultType);
 
 2683  ArrayRef<int64_t> sourceShape = sourceType.getShape();
 
 2684  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
 
 2685  SmallVector<int64_t> newShape;
 
 2688  bool newOperandNeeded = 
false;
 
 2689  for (
unsigned i = 0; i < sourceShape.size(); i++) {
 
 2690    int64_t dimShape = sourceShape[i];
 
 2691    AffineExpr dimExpr = sourceMap.
getResult(i);
 
 2692    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
 
 2693      newShape.push_back(dimShape);
 
 2699    newShape.push_back(affineExprToSize[dimExpr]);
 
 2700    newOperandNeeded = 
true;
 
 2702  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
 
 2703                                     sourceType.getEncoding());
 
 2704  if (newOperandNeeded) {
 
 2705    changeNeeded = 
true;
 
 2708    Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
 
 2710    newOperands[index] = newOperand;
 
 2712  if (linalgOp.isDpsInit(opOperand))
 
 2713    resultTypes.push_back(resultType);
 
 2719struct InferStaticShapeOfOperands : 
public OpInterfaceRewritePattern<LinalgOp> {
 
 2720  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
 2722  LogicalResult matchAndRewrite(LinalgOp linalgOp,
 
 2723                                PatternRewriter &rewriter)
 const override {
 
 2724    if (!linalgOp.hasPureTensorSemantics())
 
 2728    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
 
 2729          return !map.isProjectedPermutation();
 
 2734    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
 
 2735    Location loc = linalgOp.getLoc();
 
 2739    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
 
 2741    SmallVector<Value> newOperands;
 
 2742    SmallVector<Type> resultTypes;
 
 2746    bool changeNeeded = 
false;
 
 2747    newOperands.reserve(linalgOp->getNumOperands());
 
 2748    resultTypes.reserve(linalgOp.getNumDpsInits());
 
 2751    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
 
 2752      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
 
 2753                                      affineExprToSize, linalgOp, newOperands,
 
 2754                                      resultTypes, changeNeeded);
 
 2763    Operation *newOp = 
clone(rewriter, linalgOp, resultTypes, newOperands);
 
 2764    SmallVector<Value> replacements;
 
 2766    for (
auto it : llvm::zip(linalgOp->getResults(), newOp->
getResults())) {
 
 2767      Value newResult = std::get<1>(it);
 
 2768      Value oldResult = std::get<0>(it);
 
 2769      Type newType = newResult.
getType();
 
 2770      Type oldType = oldResult.
getType();
 
 2771      replacements.push_back(
 
 2772          (newType != oldType)
 
 2773              ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
 
 2776    rewriter.
replaceOp(linalgOp, replacements);
 
 2790LogicalResult SoftmaxOp::verify() {
 
 2791  ShapedType inputType = getInputOperandType();
 
 2792  ShapedType outputType = getOutputOperandType();
 
 2794  ArrayRef<int64_t> inputShape = inputType.getShape();
 
 2795  ArrayRef<int64_t> outputShape = outputType.getShape();
 
 2799  int64_t inputRank = getInputOperandRank();
 
 2800  int64_t dimension = getDimension();
 
 2801  if ((dimension < 0) || (dimension >= inputRank))
 
 2802    return emitOpError(
"incorrect dimension specified");
 
 2807SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
 
 2808  int64_t operandRank = getInputOperandRank();
 
 2809  SmallVector<Range> loopBounds(operandRank);
 
 2810  Location loc = getLoc();
 
 2813  Value source = getInput();
 
 2814  for (
auto dim : llvm::seq<int64_t>(0, operandRank)) {
 
 2815    loopBounds[dim].offset = zero;
 
 2816    loopBounds[dim].size = 
getDimValue(builder, loc, source, dim);
 
 2817    loopBounds[dim].stride = one;
 
 2822SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
 
 2823  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
 
 2824                                                 utils::IteratorType::parallel);
 
 2825  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
 
 2826  return iteratorTypes;
 
 2829FailureOr<TilingResult>
 
 2830SoftmaxOp::getTiledImplementation(OpBuilder &builder,
 
 2831                                  ArrayRef<OpFoldResult> offsets,
 
 2832                                  ArrayRef<OpFoldResult> sizes) {
 
 2833  int64_t rank = getInputOperandRank();
 
 2835  SmallVector<OpFoldResult> strides(rank, oneAttr);
 
 2836  SmallVector<Value> tiledOperands;
 
 2837  Operation *inputSlice =
 
 2838      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
 
 2840    return emitOpError(
"failed to compute input slice");
 
 2842  tiledOperands.emplace_back(inputSlice->
getResult(0));
 
 2843  Operation *outputSlice =
 
 2844      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
 
 2846    return emitOpError(
"failed to compute output slice");
 
 2848  tiledOperands.emplace_back(outputSlice->
getResult(0));
 
 2850  SmallVector<Type, 4> resultTypes;
 
 2851  if (hasPureTensorSemantics())
 
 2852    resultTypes.push_back(tiledOperands[1].
getType());
 
 2853  Operation *tiledOp =
 
 2854      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
 2856  return TilingResult{
 
 2859      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
 
 2862LogicalResult SoftmaxOp::getResultTilePosition(
 
 2863    OpBuilder &builder, 
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
 
 2864    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
 
 2865    SmallVector<OpFoldResult> &resultSizes) {
 
 2866  if (resultNumber == 0) {
 
 2867    resultOffsets.assign(offsets.begin(), offsets.end());
 
 2868    resultSizes.assign(sizes.begin(), sizes.end());
 
 2875LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 
 2880SoftmaxOp::reifyResultShapes(OpBuilder &
b,
 
 2882  SmallVector<OpFoldResult> shapes;
 
 2883  Location loc = getOperation()->getLoc();
 
 2884  IRRewriter rewriter(
b);
 
 2885  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
 
 2886  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
 
 2887  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
 
 2888    if (!outputShapedType.isDynamicDim(dim)) {
 
 2890      shapes.push_back(
b.getIndexAttr(inputShapedType.getDimSize(dim)));
 
 2897  reifiedReturnShapes.emplace_back(std::move(shapes));
 
 2901void SoftmaxOp::getEffects(
 
 2902    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 2904  for (
auto [index, operand] : llvm::enumerate(getDpsInputs())) {
 
 2905    if (!llvm::isa<MemRefType>(operand.
getType()))
 
 2908                         &getOperation()->getOpOperand(index), 0,
 
 2913  for (OpOperand &operand : getDpsInitsMutable()) {
 
 2914    if (!llvm::isa<MemRefType>(operand.get().
getType()))
 
 2945static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
 
 2947                                    int64_t dim, 
bool allParallel = 
false) {
 
 2949                                                 utils::IteratorType::parallel);
 
 2951    iteratorTypes[dim] = utils::IteratorType::reduction;
 
 2955  for (
int i = 0; i < inputRank; i++) {
 
 2962  return std::make_tuple(iteratorTypes, indexingMaps);
 
 
 2967template <
typename T>
 
 2970  auto inputType = cast<ShapedType>(input.
getType());
 
 2972  int64_t inputRank = inputShape.size();
 
 2973  auto [iteratorTypes, indexingMaps] =
 
 2975  assert(indexingMaps.size() == 2 &&
 
 2976         "We should have two maps: 1 for the input, 1 for the output");
 
 2977  assert(indexingMaps[0].isIdentity() && 
"input map should be identity");
 
 2979  auto genericOp = linalg::GenericOp::create(
 
 2980      builder, loc, output.
getType(), input, output, indexingMaps,
 
 2982        Value result = T::create(b, loc, args[0], args[1]);
 
 2983        linalg::YieldOp::create(b, loc, result);
 
 2985  return genericOp.getResult(0);
 
 
 2993  auto inputType = cast<ShapedType>(input.
getType());
 
 2995  int64_t inputRank = inputShape.size();
 
 2997      builder, inputRank, dim, 
true);
 
 2998  assert(indexingMaps.size() == 2 && 
"We should have one map for each input");
 
 2999  assert(indexingMaps[0].isIdentity() && 
"input map should be identity");
 
 3001  indexingMaps.push_back(indexingMaps[0]);
 
 3002  auto genericOp = linalg::GenericOp::create(
 
 3004      indexingMaps, iteratorTypes,
 
 3006        Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
 
 3007        Value result = math::ExpOp::create(b, loc, diff);
 
 3008        linalg::YieldOp::create(b, loc, result);
 
 3010  return genericOp.getResult(0);
 
 
 3020  auto inputType = cast<ShapedType>(numerator.
getType());
 
 3022  int64_t inputRank = inputShape.size();
 
 3024      builder, inputRank, dim, 
true);
 
 3025  assert(indexingMaps.size() == 2 &&
 
 3026         "We should have one map for each input (2)");
 
 3027  assert(indexingMaps[0].isIdentity() && 
"Numerator map should be identity");
 
 3029  indexingMaps.push_back(indexingMaps[0]);
 
 3030  auto genericOp = linalg::GenericOp::create(
 
 3032      output, indexingMaps, iteratorTypes,
 
 3034        Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
 
 3035        linalg::YieldOp::create(b, loc, result);
 
 3037  return genericOp.getResult(0);
 
 
 3059FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &
b) {
 
 3060  OpBuilder::InsertionGuard guard(
b);
 
 3061  b.setInsertionPoint(*
this);
 
 3062  Location loc = getLoc();
 
 3063  Value input = getInput();
 
 3064  ShapedType inputType = getInputOperandType();
 
 3065  Type elementType = inputType.getElementType();
 
 3066  int64_t reductionDim = getDimension();
 
 3068  Value output = getOutput();
 
 3069  dims.erase(dims.begin() + reductionDim);
 
 3071  Value outputReduce = tensor::EmptyOp::create(
b, loc, dims, elementType);
 
 3073                                                 elementType, 
b, loc,
 
 3075  Value neutralForMaxFInit =
 
 3076      linalg::FillOp::create(
b, loc, Value{neutralForMaxF}, outputReduce)
 
 3088      linalg::FillOp::create(
b, loc, Value{zero}, outputReduce).
result();
 
 3094      buildDivOp(
b, loc, numerator, denominator, output, reductionDim);
 
 3095  return SmallVector<Value>{
result};
 
 3102LogicalResult WinogradFilterTransformOp::verify() {
 
 3103  auto filterType = cast<ShapedType>(getFilter().
getType());
 
 3104  ArrayRef<int64_t> filterShape = filterType.getShape();
 
 3105  int64_t filterH = filterShape[getFilterHDim()];
 
 3106  int64_t filterW = filterShape[getFilterWDim()];
 
 3107  WinogradConv2DFmr fmr = getFmr();
 
 3111  if (filterH != r && filterH != 1)
 
 3112    return emitOpError(
"expect filter height either equals to r or 1");
 
 3113  if (filterW != r && filterW != 1)
 
 3114    return emitOpError(
"expect filter width either equals to r or 1");
 
 3115  if (filterH == 1 && filterW == 1)
 
 3116    return emitOpError(
"expect either filter height or width equals to r");
 
 3118  SmallVector<int64_t> expectedOutputShape;
 
 3119  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
 
 3120  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
 
 3121  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
 
 3122  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
 
 3124  auto outputType = cast<ShapedType>(getOutput().
getType());
 
 3125  ArrayRef<int64_t> outputShape = outputType.getShape();
 
 3127    return emitOpError(
"the output shape is not expected");
 
 3133WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
 
 3134  Location loc = getLoc();
 
 3137  Value filter = getFilter();
 
 3138  int64_t filterRank = getFilterOperandRank();
 
 3139  SmallVector<Range> loopBounds(filterRank);
 
 3140  for (
unsigned dim = 0; dim < filterRank; ++dim) {
 
 3141    loopBounds[dim].offset = zeroAttr;
 
 3142    loopBounds[dim].size = 
getDimValue(builder, loc, filter, dim);
 
 3143    loopBounds[dim].stride = oneAttr;
 
 3148SmallVector<utils::IteratorType>
 
 3149WinogradFilterTransformOp::getLoopIteratorTypes() {
 
 3150  int64_t filterRank = getFilterOperandRank();
 
 3151  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
 
 3152                                                 utils::IteratorType::parallel);
 
 3153  return iteratorTypes;
 
 3156LogicalResult WinogradFilterTransformOp::getResultTilePosition(
 
 3157    OpBuilder &builder, 
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
 
 3158    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
 
 3159    SmallVector<OpFoldResult> &resultSizes) {
 
 3161  ShapedType filterType = getFilterOperandType();
 
 3162  ArrayRef<int64_t> filterShape = filterType.getShape();
 
 3163  int64_t filterH = filterShape[getFilterHDim()];
 
 3164  int64_t filterW = filterShape[getFilterWDim()];
 
 3165  WinogradConv2DFmr fmr = getFmr();
 
 3168  int64_t alpha = m + r - 1;
 
 3169  int64_t alphaH = filterH != 1 ? alpha : 1;
 
 3170  int64_t alphaW = filterW != 1 ? alpha : 1;
 
 3174  resultOffsets.append(
 
 3175      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
 
 3177      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
 
 3188FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
 
 3189    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
 
 3190    ArrayRef<OpFoldResult> sizes) {
 
 3193  ShapedType filterType = getFilterOperandType();
 
 3194  ArrayRef<int64_t> filterShape = filterType.getShape();
 
 3195  int64_t filterH = filterShape[getFilterHDim()];
 
 3196  int64_t filterW = filterShape[getFilterWDim()];
 
 3199  SmallVector<Value> tiledOperands;
 
 3200  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
 3202  sliceOffsets.append(
 
 3203      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
 
 3204  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
 
 3205                     sizes[getFilterCDim()]});
 
 3206  int64_t filterRank = getFilterOperandRank();
 
 3207  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
 
 3208  Location loc = getLoc();
 
 3209  auto filterSlice = tensor::ExtractSliceOp::create(
 
 3210      builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
 
 3211  tiledOperands.emplace_back(filterSlice);
 
 3213  SmallVector<OpFoldResult> resultOffsets, resultSizes;
 
 3218  int64_t outputRank = getOutputOperandRank();
 
 3219  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
 
 3220  auto outputSlice = tensor::ExtractSliceOp::create(
 
 3221      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
 
 3222  tiledOperands.emplace_back(outputSlice);
 
 3224  SmallVector<Type> resultTypes;
 
 3225  resultTypes.push_back(tiledOperands[1].
getType());
 
 3226  Operation *tiledOp =
 
 3227      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
 3229  return TilingResult{
 
 3232      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
 
 3239LogicalResult WinogradInputTransformOp::verify() {
 
 3240  auto inputType = cast<ShapedType>(getInput().
getType());
 
 3241  ArrayRef<int64_t> inputShape = inputType.getShape();
 
 3242  int64_t inputH = inputShape[getInputHDim()];
 
 3243  int64_t inputW = inputShape[getInputWDim()];
 
 3244  WinogradConv2DFmr fmr = getFmr();
 
 3247  int64_t tileSize = m + r - 1;
 
 3249  auto outputType = cast<ShapedType>(getOutput().
getType());
 
 3250  ArrayRef<int64_t> outputShape = outputType.getShape();
 
 3251  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
 
 3252  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
 
 3254  SmallVector<int64_t> expectedOutputShape(6, inputH);
 
 3255  if (ShapedType::isDynamic(inputH)) {
 
 3256    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
 
 3257    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
 
 3259    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
 
 3260    expectedOutputShape[getOutputTileHDim()] =
 
 3261        leftTransform ? (inputH - (r - 1)) / m : inputH;
 
 3263  if (ShapedType::isDynamic(inputW)) {
 
 3264    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
 
 3265    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
 
 3267    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
 
 3268    expectedOutputShape[getOutputTileWDim()] =
 
 3269        rightTransform ? (inputW - (r - 1)) / m : inputW;
 
 3271  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
 
 3272  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
 
 3275    return emitOpError(
"the output shape is not expected");
 
 3281WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
 
 3282  Location loc = getLoc();
 
 3285  Value output = getOutput();
 
 3286  int64_t outputRank = getOutputOperandRank();
 
 3287  SmallVector<Range> loopBounds(outputRank);
 
 3288  for (
unsigned dim = 0; dim < outputRank; ++dim) {
 
 3289    loopBounds[dim].offset = zeroAttr;
 
 3291    loopBounds[dim].size = 
getDimValue(builder, loc, output, dim);
 
 3292    loopBounds[dim].stride = oneAttr;
 
 3297SmallVector<utils::IteratorType>
 
 3298WinogradInputTransformOp::getLoopIteratorTypes() {
 
 3299  int64_t outputRank = getOutputOperandRank();
 
 3300  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
 
 3301                                                 utils::IteratorType::parallel);
 
 3302  return iteratorTypes;
 
 3305LogicalResult WinogradInputTransformOp::getResultTilePosition(
 
 3306    OpBuilder &builder, 
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
 
 3307    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
 
 3308    SmallVector<OpFoldResult> &resultSizes) {
 
 3310  ShapedType outputType = getOutputOperandType();
 
 3311  ArrayRef<int64_t> outputShape = outputType.getShape();
 
 3312  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
 
 3313  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
 
 3315  WinogradConv2DFmr fmr = getFmr();
 
 3318  int64_t alpha = m + r - 1;
 
 3319  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
 
 3320  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
 
 3325  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
 
 3326                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
 
 3327                        offsets[getOutputCDim()]});
 
 3328  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
 
 3329                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
 
 3330                      sizes[getOutputCDim()]});
 
 3341FailureOr<TilingResult>
 
 3342WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
 
 3343                                                 ArrayRef<OpFoldResult> offsets,
 
 3344                                                 ArrayRef<OpFoldResult> sizes) {
 
 3346  WinogradConv2DFmr fmr = getFmr();
 
 3350  ShapedType outputType = getOutputOperandType();
 
 3351  ArrayRef<int64_t> outputShape = outputType.getShape();
 
 3352  int64_t alphaH = outputShape[getOutputAlphaHDim()];
 
 3353  int64_t alphaW = outputShape[getOutputAlphaWDim()];
 
 3355  Location loc = getLoc();
 
 3357  auto identityAffineMap =
 
 3359  auto offsetAffineMap =
 
 3362      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
 
 3363      offsets[getOutputTileHDim()]);
 
 3365      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
 
 3366      offsets[getOutputTileWDim()]);
 
 3370      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
 
 3372      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
 
 3374  SmallVector<Value> tiledOperands;
 
 3375  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
 3377  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
 
 3378  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
 
 3379  sliceOffsets.append(
 
 3380      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
 
 3381  OpFoldResult sizeH =
 
 3382      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
 
 3383  OpFoldResult sizeW =
 
 3384      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
 
 3386      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
 
 3387  int64_t inputRank = getInputOperandRank();
 
 3388  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
 
 3389  auto inputSlice = tensor::ExtractSliceOp::create(
 
 3390      builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
 
 3391  tiledOperands.emplace_back(inputSlice);
 
 3393  SmallVector<OpFoldResult> resultOffsets, resultSizes;
 
 3398  int64_t outputRank = getOutputOperandRank();
 
 3399  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
 
 3400  auto outputSlice = tensor::ExtractSliceOp::create(
 
 3401      builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
 
 3402  tiledOperands.emplace_back(outputSlice);
 
 3404  SmallVector<Type> resultTypes;
 
 3405  resultTypes.push_back(tiledOperands[1].
getType());
 
 3406  Operation *tiledOp =
 
 3407      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
 3409  return TilingResult{
 
 3412      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
 
 3419LogicalResult WinogradOutputTransformOp::verify() {
 
 3420  auto valueType = cast<ShapedType>(getValue().
getType());
 
 3421  ArrayRef<int64_t> valueShape = valueType.getShape();
 
 3422  int64_t valueH = valueShape[getValueAlphaHDim()];
 
 3423  int64_t valueW = valueShape[getValueAlphaWDim()];
 
 3424  int64_t valueTileH = valueShape[getValueTileHDim()];
 
 3425  int64_t valueTileW = valueShape[getValueTileWDim()];
 
 3426  WinogradConv2DFmr fmr = getFmr();
 
 3429  bool leftTransform = valueH != 1;
 
 3430  bool rightTransform = valueW != 1;
 
 3432  int64_t outputRank = getOutputOperandRank();
 
 3433  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
 
 3434  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
 
 3435    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
 
 3437    if (valueH != (leftTransform ? m + r - 1 : 1))
 
 3438      return emitOpError(
"expect input height equals to input tile size");
 
 3439    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
 
 3441  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
 
 3442    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
 
 3444    if (valueW != (rightTransform ? m + r - 1 : 1))
 
 3445      return emitOpError(
"expect input width equals to input tile size");
 
 3446    expectedOutputShape[getOutputWDim()] =
 
 3447        (rightTransform ? m : 1) * valueTileW;
 
 3449  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
 
 3450  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
 
 3452  auto outputType = cast<ShapedType>(getOutput().
getType());
 
 3453  ArrayRef<int64_t> outputShape = outputType.getShape();
 
 3455    return emitOpError(
"the output shape is not expected");
 
 3461WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
 
 3462  Location loc = getLoc();
 
 3465  Value value = getValue();
 
 3466  int64_t valueRank = getValueOperandRank();
 
 3467  SmallVector<Range> loopBounds(valueRank);
 
 3468  for (
unsigned dim = 0; dim < valueRank; ++dim) {
 
 3469    loopBounds[dim].offset = zeroAttr;
 
 3471    loopBounds[dim].size = 
getDimValue(builder, loc, value, dim);
 
 3472    loopBounds[dim].stride = oneAttr;
 
 3477SmallVector<utils::IteratorType>
 
 3478WinogradOutputTransformOp::getLoopIteratorTypes() {
 
 3479  int64_t valueRank = getValueOperandRank();
 
 3480  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
 
 3481                                                 utils::IteratorType::parallel);
 
 3482  return iteratorTypes;
 
 3485LogicalResult WinogradOutputTransformOp::getResultTilePosition(
 
 3486    OpBuilder &builder, 
unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
 
 3487    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
 
 3488    SmallVector<OpFoldResult> &resultSizes) {
 
 3489  WinogradConv2DFmr fmr = getFmr();
 
 3493  Location loc = getLoc();
 
 3495  auto identityAffineMap =
 
 3500  ShapedType valueType = getValueOperandType();
 
 3501  ArrayRef<int64_t> valueShape = valueType.getShape();
 
 3502  int64_t valueH = valueShape[0];
 
 3503  int64_t valueW = valueShape[1];
 
 3505      builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
 
 3506      offsets[getValueTileHDim()]);
 
 3508      builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
 
 3509      offsets[getValueTileWDim()]);
 
 3511      builder, loc, affineMap, sizes[getValueTileHDim()]);
 
 3513      builder, loc, affineMap, sizes[getValueTileWDim()]);
 
 3516  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
 
 3517  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
 
 3518  OpFoldResult sizeH =
 
 3519      valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
 
 3520  OpFoldResult sizeW =
 
 3521      valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
 
 3523  resultOffsets.append(
 
 3524      {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
 
 3526      {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
 
 3536FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
 
 3537    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
 
 3538    ArrayRef<OpFoldResult> sizes) {
 
 3541  Location loc = getLoc();
 
 3542  SmallVector<Value> tiledOperands;
 
 3543  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
 
 3545  ShapedType valueType = getValueOperandType();
 
 3546  ArrayRef<int64_t> valueShape = valueType.getShape();
 
 3547  int64_t alphaH = valueShape[getValueAlphaHDim()];
 
 3548  int64_t alphaW = valueShape[getValueAlphaWDim()];
 
 3552  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
 
 3553                       offsets[getValueTileWDim()], offsets[getValueNDim()],
 
 3554                       offsets[getValueFDim()]});
 
 3555  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
 
 3556                     sizes[getValueTileWDim()], sizes[getValueNDim()],
 
 3557                     sizes[getValueFDim()]});
 
 3558  int64_t valueRank = getValueOperandRank();
 
 3559  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
 
 3560  auto valueSlice = tensor::ExtractSliceOp::create(
 
 3561      builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
 
 3562  tiledOperands.emplace_back(valueSlice);
 
 3564  SmallVector<OpFoldResult> resultOffsets, resultSizes;
 
 3569  int64_t outputRank = getOutputOperandRank();
 
 3570  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
 
 3571  auto outputSlice = tensor::ExtractSliceOp::create(
 
 3572      builder, loc, getOutput(), resultOffsets, resultSizes, strides);
 
 3573  tiledOperands.emplace_back(outputSlice);
 
 3575  SmallVector<Type> resultTypes;
 
 3576  resultTypes.push_back(tiledOperands[1].
getType());
 
 3577  Operation *tiledOp =
 
 3578      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
 3580  return TilingResult{
 
 3583      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
 
 3597  llvm::set_union(explicitSet, defaultSet);
 
 3598  return explicitSet == defaultSet;
 
 
 3618      matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
 
 3620  auto opIndexingMap = opIndexingMaps[opIndex];
 
 3621  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
 
 3624    return matmulOp->emitOpError()
 
 3625           << 
"Unexpected dim expression in map result.";
 
 3628    if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
 
 3629      return matmulOp->emitOpError()
 
 3630             << 
"Invalid broadcast requested, should be (d2).";
 
 
 3639template <
typename OpTy>
 
 3642                                     AffineMap defaultIndexingMap, 
bool isLHS) {
 
 3643  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
 
 3644          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
 
 3645         "Expected BatchMatmulOp or BatchReduceMatmulOp");
 
 3648    return batchVariantMatmulOp->emitOpError()
 
 3649           << 
"Unexpected result dim expression (outside the set of default " 
 3654    return batchVariantMatmulOp->emitOpError()
 
 3655           << 
"no. of result dim expressions exceeds 3.";
 
 3657  auto hasValidBatchDim = [](
AffineMap map) {
 
 3664    if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
 
 3665      return batchVariantMatmulOp->emitOpError()
 
 3666             << 
"Invalid broadcast requested.";
 
 3667  } 
else if (!hasValidBatchDim(opIndexingMap)) {
 
 3668    return batchVariantMatmulOp->emitOpError()
 
 3669           << 
"Invalid batch dimension expression.";
 
 
 3677template <
typename OpTy>
 
 3680  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
 
 3681          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
 
 3682         "Expected BatchMatmulOp or BatchReduceMatmulOp");
 
 3683  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
 
 3686    return batchVariantMatmulOp->emitOpError()
 
 3687           << 
"expects 3 dims, but got (" << opIndexingMap.
getNumResults()
 
 3690  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
 
 3692    return batchVariantMatmulOp->emitOpError()
 
 3693           << 
"expects 2 dims, but got (" << opIndexingMap.
getNumResults()
 
 3697  auto areValidOutputResultDim = [&](
AffineMap outputMap) {
 
 3698    return isa<BatchMatmulOp>(batchVariantMatmulOp)
 
 3699               ? outputMap.getResult(0).isFunctionOfDim(0) &&
 
 3700                     outputMap.getResult(1).isFunctionOfDim(1) &&
 
 3701                     outputMap.getResult(2).isFunctionOfDim(2)
 
 3702               : outputMap.getResult(0).isFunctionOfDim(1) &&
 
 3703                     outputMap.getResult(1).isFunctionOfDim(2);
 
 3706  if (!areValidOutputResultDim(opIndexingMap)) {
 
 3707    return batchVariantMatmulOp->emitOpError()
 
 3708           << 
"Invalid output map result dimension.";
 
 
 3717template <
typename OpTy>
 
 3722      batchVariantMatmulOp.getIndexingMapsArray();
 
 3724      batchVariantMatmulOp.getDefaultIndexingMaps(
 
 3725          batchVariantMatmulOp->getContext());
 
 3727  if (opIndexingMaps.size() != 3)
 
 3728    return batchVariantMatmulOp->emitOpError()
 
 3729           << 
"Indexing_map attribute must have 3 affine maps.";
 
 3731  auto opIndexingMap = opIndexingMaps[opIndex];
 
 3732  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
 
 3740                             defaultIndexingMap, opIndex == 0)))
 
 
 3750  if (m == 2 && r == 3)
 
 3751    return WinogradConv2DFmr::F_2_3;
 
 3752  if (m == 4 && r == 3)
 
 3753    return WinogradConv2DFmr::F_4_3;
 
 3754  if (m == 2 && r == 5)
 
 3755    return WinogradConv2DFmr::F_2_5;
 
 3756  return std::nullopt;
 
 
 3761  case WinogradConv2DFmr::F_2_3:
 
 3763  case WinogradConv2DFmr::F_4_3:
 
 3765  case WinogradConv2DFmr::F_2_5:
 
 
 3774static FailureOr<SmallVector<SmallVector<int64_t>>>
 
 3777  for (
auto map : maps) {
 
 3778    AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
 
 3782    for (
auto result : attr.getAffineMap().getResults()) {
 
 3783      auto dim = dyn_cast<AffineDimExpr>(
result);
 
 3786      pos.push_back(dim.getPosition());
 
 3788    positions.push_back(pos);
 
 
 3801  return indexingMaps;
 
 3804bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
 
 3805  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
 
 3808  if (maps.size() != 3)
 
 3813  return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
 
 3814         (*positions)[1] == SmallVector<int64_t>{2, 1} &&
 
 3815         (*positions)[2] == SmallVector<int64_t>{0, 1};
 
 3818SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
 
 3819  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
 
 3820                                          utils::IteratorType::parallel,
 
 3821                                          utils::IteratorType::reduction};
 
 3824unsigned MatmulOp::getNumRegionArgs() { 
return 3; }
 
 3826std::string MatmulOp::getLibraryCallName() {
 
 3830bool MatmulOp::hasDynamicIndexingMaps() { 
return true; }
 
 3834bool MatmulOp::hasUserDefinedMaps() {
 
 3835  SmallVector<AffineMap, 3> defaultMaps =
 
 3837  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
 
 3838  return defaultMaps != explicitMaps;
 
 3843void MatmulOp::regionBuilder(ImplicitLocOpBuilder &
b, 
Block &block,
 
 3844                             ArrayRef<NamedAttribute> attrs,
 
 3847    emitError() << 
"MatmulOp regionBuilder expects 3 args, got " 
 3852         "MatmulOp regionBuilder expects 3 args");
 
 3853  RegionBuilderHelper helper(
b, block);
 
 3854  SmallVector<Value> yields;
 
 3856  TypeFn castVal = TypeFn::cast_signed;
 
 3857  const auto *castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
 
 3858    return attr.
getName() == 
"cast";
 
 3860  if (castIter != attrs.end()) {
 
 3861    if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
 
 3869  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, 
emitError);
 
 3872  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2),
 
 3876  yields.push_back(value4);
 
 3877  helper.yieldOutputs(yields);
 
 3887bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
 
 3888  assert(bcastMap.
getNumResults() == 1 && 
"Expected single result dim expr.");
 
 3889  AffineExpr expr = bcastMap.
getResult(0);
 
 3903  if (llvm::any_of(arrayAttr,
 
 3904                   [](
auto elt) { 
return !dyn_cast<AffineMapAttr>(elt); }))
 
 3906           << 
"element of indexing_maps array is not an affine_map";
 
 
 3913  if (failed(indexingMapsAttr))
 
 3916  if (*indexingMapsAttr == 
nullptr) {
 
 3917    auto indexingMapAttrs = llvm::map_to_vector(
 
 3918        MatmulOp::getDefaultIndexingMaps(parser.
getContext()),
 
 3923  result.addAttribute(
"indexing_maps", *indexingMapsAttr);
 
 3925                                MatmulOp::getRegionBuilder());
 
 3928void MatmulOp::print(OpAsmPrinter &p) {
 
 3929  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
 
 3930      MatmulOp::getDefaultIndexingMaps(
getContext()),
 
 3931      [](AffineMap map) -> Attribute { 
return AffineMapAttr::get(map); });
 
 3932  if (!llvm::equal(getIndexingMaps(), indexingMaps))
 
 3933    p << 
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
 
 3935  std::array<StringRef, 3> elidedAttrs = {
 
 3936      "operandSegmentSizes", 
"linalg.memoized_indexing_maps", 
"indexing_maps"};
 
 3942LogicalResult MatmulOp::verify() {
 
 3944  if (!hasUserDefinedMaps())
 
 3947  for (
unsigned opIndex = 0; opIndex < 2; opIndex++) {
 
 3954LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 
 3958void MatmulOp::getEffects(
 
 3959    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 3961  if (hasPureTensorSemantics())
 
 3970SmallVector<AffineMap>
 
 3971MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
 
 3972  AffineExpr d0, d1, d2;
 
 3978  return {mapLHS, mapRHS, mapOut};
 
 3982  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
 
 3985  if (maps.size() != 3)
 
 3988  if (failed(positions))
 
 
 4000                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
 
 
 4008  build(builder, state, inputs, outputs, attributes);
 
 4009  auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
 
 4010  assert(res && 
"builder didn't return the right type");
 
 
 4020                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
 
 
 4029  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
 
 4030  auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
 
 4031  assert(res && 
"builder didn't return the right type");
 
 
 4041  result.addAttribute(
"cast", cast);
 
 4043                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
 
 
 4052  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
 
 4053  auto res = dyn_cast<MatmulTransposeAOp>(builder.
create(state));
 
 4054  assert(res && 
"builder didn't return the right type");
 
 
 4059  return dyn_cast_or_null<linalg::MatmulOp>(op) &&
 
 4061             op->
getAttr(
"indexing_maps"));
 
 
 4065MatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
 
 4072  return {mapLHS, mapRHS, mapOut};
 
 4076  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
 
 4079  if (maps.size() != 3)
 
 4082  if (failed(positions))
 
 
 4094                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
 
 
 4102  build(builder, state, inputs, outputs, attributes);
 
 4103  auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
 
 4104  assert(res && 
"builder didn't return the right type");
 
 
 4114                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
 
 
 4123  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
 
 4124  auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
 
 4125  assert(res && 
"builder didn't return the right type");
 
 
 4135  result.addAttribute(
"cast", cast);
 
 4137                MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
 
 
 4146  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
 
 4147  auto res = dyn_cast<MatmulTransposeBOp>(builder.
create(state));
 
 4148  assert(res && 
"builder didn't return the right type");
 
 
 4153  return dyn_cast_or_null<linalg::MatmulOp>(op) &&
 
 4155             op->
getAttr(
"indexing_maps"));
 
 
 4159BatchMatmulTransposeAOp::getDefaultIndexingMaps(
OpBuilder &builder) {
 
 4166  return {mapLHS, mapRHS, mapOut};
 
 4170  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
 
 4173  if (maps.size() != 3)
 
 4176  if (failed(positions))
 
 
 4187                BatchMatmulOp::getRegionBuilder(),
 
 4188                getDefaultIndexingMaps(builder));
 
 
 4196  build(builder, state, inputs, outputs, attributes);
 
 4197  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
 
 4198  assert(res && 
"builder didn't return the right type");
 
 
 4207                BatchMatmulOp::getRegionBuilder(),
 
 4208                getDefaultIndexingMaps(builder));
 
 
 4217  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
 
 4218  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
 
 4219  assert(res && 
"builder didn't return the right type");
 
 
 4227  result.addAttribute(
"cast", cast);
 
 4229                BatchMatmulOp::getRegionBuilder(),
 
 4230                getDefaultIndexingMaps(builder));
 
 
 4239  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
 
 4240  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.
create(state));
 
 4241  assert(res && 
"builder didn't return the right type");
 
 
 4246  return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
 
 4248             op->
getAttr(
"indexing_maps"));
 
 
 4252BatchMatmulTransposeBOp::getDefaultIndexingMaps(
OpBuilder &builder) {
 
 4259  return {mapLHS, mapRHS, mapOut};
 
 4263  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
 
 4266  if (maps.size() != 3)
 
 4269  if (failed(positions))
 
 
 4280                BatchMatmulOp::getRegionBuilder(),
 
 4281                getDefaultIndexingMaps(builder));
 
 
 4289  build(builder, state, inputs, outputs, attributes);
 
 4290  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
 
 4291  assert(res && 
"builder didn't return the right type");
 
 
 4300                BatchMatmulOp::getRegionBuilder(),
 
 4301                getDefaultIndexingMaps(builder));
 
 
 4310  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
 
 4311  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
 
 4312  assert(res && 
"builder didn't return the right type");
 
 
 4320  result.addAttribute(
"cast", cast);
 
 4322                BatchMatmulOp::getRegionBuilder(),
 
 4323                getDefaultIndexingMaps(builder));
 
 
 4332  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
 
 4333  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.
create(state));
 
 4334  assert(res && 
"builder didn't return the right type");
 
 
 4339  return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
 
 4341             op->
getAttr(
"indexing_maps"));
 
 
 4349  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
 
 4360    auto dimExpr = dyn_cast<AffineDimExpr>(
result);
 
 4361    assert(dimExpr && 
"affine_map is a projected permutation");
 
 4362    dimsInOutput[dimExpr.getPosition()] = 
true;
 
 4366  for (
auto dimOccursInOutput : dimsInOutput)
 
 4367    iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
 
 4368                                              : utils::IteratorType::reduction);
 
 4370  return iteratorTypes;
 
 /// Returns the number of block arguments of the ContractOp region.
 /// This is the constant 3, the same count named in the "ContractOp
 /// regionBuilder expects 3 args" diagnostics emitted by the region builder.
 4373unsigned ContractOp::getNumRegionArgs() { 
return 3; }
 
 4376void ContractOp::regionBuilder(ImplicitLocOpBuilder &
b, 
Block &block,
 
 4377                               ArrayRef<NamedAttribute> attrs,
 
 4380    emitError() << 
"ContractOp regionBuilder expects 3 args, got " 
 4385         "ContractOp regionBuilder expects 3 args");
 
 4386  RegionBuilderHelper helper(
b, block);
 
 4388  TypeFn castSignedness = TypeFn::cast_signed;
 
 4389  auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
 
 4390    return attr.
getName() == 
"cast";
 
 4392  if (castIter != attrs.end()) {
 
 4393    if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
 
 4399  Value lhsAtOutType =
 
 4400      helper.buildTypeFn(castSignedness, outType, block.
getArgument(0));
 
 4401  Value rhsAtOutType =
 
 4402      helper.buildTypeFn(castSignedness, outType, block.
getArgument(1));
 
 4403  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
 
 4405  if (!productAtOutType)
 
 4411  helper.yieldOutputs({
result});
 
 4414ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 4416  if (
failed(indexingMapsAttr) || *indexingMapsAttr == 
nullptr)
 
 4418                            "expected 'indexing_maps' attribute");
 
 4419  result.addAttribute(
"indexing_maps", *indexingMapsAttr);
 
 4425void ContractOp::print(OpAsmPrinter &p) {
 
 4426  p << 
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
 
 4428      p, getOperation(), getInputs(), getOutputs(),
 
 4429      {
"indexing_maps", 
"operandSegmentSizes"});
 
 4432LogicalResult ContractOp::verify() {
 
 4433  int iterationSpaceDims = -1;
 
 4438  SmallVector<size_t> inOccurrences;
 
 4439  SmallVector<size_t> outOccurrences;
 
 4442  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
 
 4443                                   bool isInput) -> LogicalResult {
 
 4446      return emitError(
"provided affine_map is not a projected permutation");
 
 4449    if (
auto shapedType = dyn_cast<ShapedType>(operandType)) {
 
 4451        return emitError(
"ranks of shaped operand and results of corresponding " 
 4452                         "affine_map differ");
 
 4454      return emitError(
"affine_map specifies shaped access while operand has " 
 4459    if (iterationSpaceDims == -1) {
 
 4461      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
 
 4462      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
 
 4463    } 
else if (iterationSpaceDims != (
int)affineMap.
getNumDims()) {
 
 4464      return emitError(
"iteration spaces of provided affine_maps differ");
 
 4468    for (AffineExpr affineExpr : affineMap.
getResults()) {
 
 4469      auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
 
 4471        llvm_unreachable(
"affine_map is a projected permutation");
 
 4474        inOccurrences[affineDimExpr.getPosition()] += 1;
 
 4476        outOccurrences[affineDimExpr.getPosition()] += 1;
 
 4482  for (
auto &&[affineMap, operandType, isInput] :
 
 4483       llvm::zip(getIndexingMapsArray(), getOperandTypes(),
 
 4484                 SmallVector<bool>{
true, 
true, 
false})) {
 
 4485    if (
failed(checkAffineMapAndType(affineMap, operandType, isInput)))
 
 4489  bool hasContractingDim = 
false;
 
 4490  for (
size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
 
 4491    size_t inOccCount = inOccurrences[dimIndex];
 
 4492    size_t outOccCount = outOccurrences[dimIndex];
 
 4495    hasContractingDim |= inOccCount == 2 && outOccCount == 0;
 
 4497    if (inOccCount == 0 && outOccCount == 0)
 
 4498      return emitError() << 
"iteration space dim at index " << dimIndex
 
 4499                         << 
" not used to access any operand";
 
 4510    if (inOccCount == 1 && outOccCount != 1)
 
 4512             << 
"iteration space dim at index " << dimIndex
 
 4513             << 
" is neither a contracting dim nor of parallel iteration type";
 
 4516  if (!hasContractingDim)
 
 4517    return emitError(
"'indexing_maps' do not specify a contracting dimension");
 
 4522LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
 
 4526void ContractOp::getEffects(
 
 4527    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 4529  if (hasPureTensorSemantics())
 
 4541SmallVector<AffineMap>
 
 4542BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
 
 4543  AffineExpr d0, d1, d2, d3;
 
 4544  SmallVector<AffineMap> indexingMaps;
 
 4546  indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
 
 4547  indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
 
 4548  indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d2}, context));
 
 4549  return indexingMaps;
 
 4552bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
 
 4553  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
 
 4556  if (maps.size() != 3)
 
 4561  return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
 
 4562         (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
 
 4563         (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
 
 4566SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
 
 4567  return SmallVector<utils::IteratorType>{
 
 4568      utils::IteratorType::parallel, utils::IteratorType::parallel,
 
 4569      utils::IteratorType::parallel, utils::IteratorType::reduction};
 
 /// Returns the number of block arguments of the BatchMatmulOp region.
 /// This is the constant 3: the region builder multiplies block arguments 0
 /// and 1 (after casting) and adds the product to block argument 2.
 4572unsigned BatchMatmulOp::getNumRegionArgs() { 
return 3; }
 
 4574std::string BatchMatmulOp::getLibraryCallName() {
 
 4580bool BatchMatmulOp::hasUserDefinedMaps() {
 
 4581  SmallVector<AffineMap, 3> defaultMaps =
 
 4583  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
 
 4584  return defaultMaps != explicitMaps;
 
 4594bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, 
bool isLHS) {
 
 4596         "Expected less than 3 result dim expr.");
 
 4597  bool isValid = 
false;
 
 4598  enum Indices { batchPos, mPos, nPos, kPos };
 
 4600    AffineExpr expr = bcastMap.
getResult(0);
 
 4603    AffineExpr expr0 = bcastMap.
getResult(0);
 
 4604    AffineExpr expr1 = bcastMap.
getResult(1);
 
 4609              : ((expr0.isFunctionOfDim(batchPos) &&
 
 4610                  expr1.isFunctionOfDim(kPos)) ||
 
 4611                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
 
 4616void BatchMatmulOp::regionBuilder(
 
 4617    ImplicitLocOpBuilder &
b, 
Block &block, ArrayRef<NamedAttribute> attrs,
 
 4620    emitError() << 
"BatchMatmulOp regionBuilder expects 3 args, got " 
 4625         "BatchMatmulOp regionBuilder expects 3 args");
 
 4626  RegionBuilderHelper helper(
b, block);
 
 4627  SmallVector<Value> yields;
 
 4629  TypeFn castVal = TypeFn::cast_signed;
 
 4630  auto castIter = llvm::find_if(attrs, [&](
const NamedAttribute &attr) {
 
 4631    return attr.
getName() == 
"cast";
 
 4633  if (castIter != attrs.end()) {
 
 4634    if (
auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
 
 4639  Value castValA = helper.buildTypeFn(castVal, toType, block.
getArgument(0));
 
 4640  Value castValB = helper.buildTypeFn(castVal, toType, block.
getArgument(1));
 
 4641  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
 
 4643      helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
 
 4644  yields.push_back(addVal);
 
 4645  helper.yieldOutputs(yields);
 
 4648ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 4649  SmallVector<Attribute, 3> indexingMapsAttr;
 
 4661      if (!isa<AffineMapAttr>(mapAttr)) {
 
 4663                                "expected affine map attribute");
 
 4665      indexingMapsAttr.push_back(mapAttr);
 
 4675  if (indexingMapsAttr.empty()) {
 
 4676    indexingMapsAttr = llvm::map_to_vector(
 
 4677        BatchMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
 
 4678        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
 
 4680  result.addAttribute(
"indexing_maps",
 
 4683  return ::parseNamedStructuredOp(parser, 
result,
 
 4684                                  BatchMatmulOp::getNumRegionArgs(),
 
 4685                                  BatchMatmulOp::getRegionBuilder());
 
 4688void BatchMatmulOp::print(OpAsmPrinter &p) {
 
 4689  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
 
 4690      BatchMatmulOp::getDefaultIndexingMaps(
getContext()),
 
 4691      [](AffineMap map) -> Attribute { 
return AffineMapAttr::get(map); });
 
 4692  if (!llvm::equal(getIndexingMaps(), indexingMaps))
 
 4693    p << 
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
 
 4695  std::array<StringRef, 3> elidedAttrs = {
 
 4696      "operandSegmentSizes", 
"linalg.memoized_indexing_maps", 
"indexing_maps"};
 
 4702LogicalResult BatchMatmulOp::verify() {
 
 4705  if (!hasUserDefinedMaps())
 
 4708  for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
 
 4715LogicalResult BatchMatmulOp::fold(FoldAdaptor,
 
 4716                                  SmallVectorImpl<OpFoldResult> &) {
 
 4720void BatchMatmulOp::getEffects(
 
 4721    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 4723  if (hasPureTensorSemantics())
 
 4737struct ArityGroupAndKind {
 
 4739  ElementwiseArityGroup arityGroup;
 
 4745    TernaryFn ternaryFn;
 
 4749unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
 
 4750  return static_cast<unsigned>(arityGroup);
 
 4755  constexpr int lastUnary = 
static_cast<int>(ElementwiseCaseLimits::LastUnary);
 
 4756  constexpr int lastBinary =
 
 4757      static_cast<int>(ElementwiseCaseLimits::LastBinary);
 
 4758  constexpr int lastTernary =
 
 4759      static_cast<int>(ElementwiseCaseLimits::LastTernary);
 
 4761  int val = 
static_cast<int>(kind);
 
 4762  ArityGroupAndKind 
result;
 
 4764  if (val < lastUnary) {
 
 4765    result.arityGroup = ElementwiseArityGroup::Unary;
 
 4766    result.kind.unaryFn = 
static_cast<UnaryFn
>(val);
 
 4769  if (val < lastBinary) {
 
 4770    result.arityGroup = ElementwiseArityGroup::Binary;
 
 4771    result.kind.binaryFn = 
static_cast<BinaryFn
>(val - lastUnary);
 
 4774  if (val >= lastTernary) {
 
 4775    llvm_unreachable(
"unhandled ElementwiseFn");
 
 4777  result.arityGroup = ElementwiseArityGroup::Ternary;
 
 4778  result.kind.ternaryFn = 
static_cast<TernaryFn
>(val - lastBinary);
 
 
 4783  auto rank = getResultRank();
 
 4788ElementwiseOp::getDefaultIndexingMaps(
unsigned numMaps, 
unsigned numDims,
 
 4794ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &
result) {
 
 4797  mlir::linalg::ElementwiseKind elemwiseKindVal;
 
 4802    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
 
 4803    if (!elemwiseKindAttr)
 
 4805                              "expected ElementwiseKind attribute");
 
 4806    elemwiseKindVal = elemwiseKindAttr.getValue();
 
 4809                            "expected operation 'kind' attribute");
 
 4812      "kind", ElementwiseKindAttr::get(parser.
getContext(), elemwiseKindVal));
 
 4815  SmallVector<Attribute, 3> indexingMapsAttr;
 
 4825      if (!isa<AffineMapAttr>(mapAttr))
 
 4827                                "expected affine map attribute");
 
 4828      indexingMapsAttr.push_back(mapAttr);
 
 4839      getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 ;
 
 4841                             ElementwiseOp::getRegionBuilder())) {
 
 4843                            "unable to parse elemwise op");
 
 4847  if (indexingMapsAttr.empty()) {
 
 4850    auto resultType = 
result.operands[
result.operands.size() - 1].getType();
 
 4851    auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
 
 4854                              "return type needs to be shaped type");
 
 4855    auto numDims = shapedType.getRank();
 
 4856    indexingMapsAttr = llvm::map_to_vector(
 
 4857        ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
 
 4859        [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
 
 4862  result.addAttribute(
"indexing_maps",
 
 4867void ElementwiseOp::print(OpAsmPrinter &p) {
 
 4870  SmallVector<StringRef, 3> elidedAttrs = {
"operandSegmentSizes", 
"kind",
 
 4874  unsigned numDims = getResultRank();
 
 4876  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
 
 4877      ElementwiseOp::getDefaultIndexingMaps(arity + 1 , numDims,
 
 4879      [](AffineMap map) -> Attribute { 
return AffineMapAttr::get(map); });
 
 4881  if (!llvm::equal(getIndexingMaps(), indexingMaps))
 
 4882    p << 
" indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
 
 4888LogicalResult ElementwiseOp::verify() {
 
 4897void ElementwiseOp::regionBuilder(
 
 4898    ImplicitLocOpBuilder &
b, 
Block &block, ArrayRef<NamedAttribute> attrs,
 
 4900  ElementwiseKind elemwiseKind;
 
 4901  for (
auto attr : attrs) {
 
 4902    if (attr.getName() == 
b.getStringAttr(
"kind")) {
 
 4903      auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
 
 4904      assert(kindAttr && 
"op kind attribute incorrectly set");
 
 4905      elemwiseKind = kindAttr.getValue();
 
 4911  auto arityGroup = groupAndKind.arityGroup;
 
 4912  auto kind = groupAndKind.kind;
 
 4914                       getArityGroupAsUInt(arityGroup) + 1 ) {
 
 4915    emitError() << 
"Elementwise regionBuilder expects " 
 4916                << (getArityGroupAsUInt(arityGroup) + 1) << 
" args, got " 
 4921             getArityGroupAsUInt(arityGroup) + 1 
 
 4922         && 
"Elementwise regionBuilder number of block args mismatch");
 
 4924  RegionBuilderHelper helper(
b, block);
 
 4925  SmallVector<Value> yields;
 
 4928  if (arityGroup == ElementwiseArityGroup::Unary) {
 
 4931  } 
else if (arityGroup == ElementwiseArityGroup::Binary) {
 
 4935  } 
else if (arityGroup == ElementwiseArityGroup::Ternary) {
 
 4940    assert(
false && 
"found unhandled category in elemwise");
 
 4943  yields.push_back(
result);
 
 4944  helper.yieldOutputs(yields);
 
 4947LogicalResult ElementwiseOp::fold(FoldAdaptor,
 
 4948                                  SmallVectorImpl<OpFoldResult> &) {
 
 4952void ElementwiseOp::getEffects(
 
 4953    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
 
 4955  if (hasPureTensorSemantics())
 
 4968template <
typename OpTy, 
typename>
 
 4971  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
 
 4972                                    ? packOrUnPack.getDestType()
 
 4973                                    : packOrUnPack.getSourceType();
 
 4974  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
 
 4975                                      ? packOrUnPack.getSourceType()
 
 4976                                      : packOrUnPack.getDestType();
 
 4978      packedType.getShape().take_front(unpackedType.getRank()));
 
 4979  if (!packOrUnPack.getOuterDimsPerm().empty()) {
 
 
 5001  for (
auto it : llvm::zip(cast<ShapedType>(newPackedTy)
 
 5003                               .take_back(mixedTiles.size()),
 
 5006    if (
shape == ShapedType::kDynamic) {
 
 5007      newMixedTileSizes.push_back(std::get<1>(it));
 
 5014    if (
Attribute attr = llvm::dyn_cast_if_present<Attribute>(
tile)) {
 
 5016      newMixedTileSizes.push_back(
tile);
 
 5019             "tile size and dim size don't match!");
 
 5020      newMixedTileSizes.push_back(
 
 5025  return newMixedTileSizes;
 
 
 5028template <
typename OpTy>
 
 5032  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
 
 5033                "applies to only pack or unpack operations");
 
 5034  int64_t destRank = op.getDestRank();
 
 5036  reifiedReturnShapes[0] =
 
 
 5041template <
typename OpTy>
 
 5043  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
 
 5044                "applies to only pack or unpack operations");
 
 5048  assert(tiles.size() == dimsToTile.size() &&
 
 5049         "tiles must match indices of dimension to block");
 
 5051  for (
auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
 
 5052    dimAndTileMapping[dimsToTile[i]] = tiles[i];
 
 5053  return dimAndTileMapping;
 
 
 5056template <
typename OpTy>
 
 5058  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
 
 5059                "applies to only pack or unpack operations");
 
 5062  unsigned dynamicValIndex = 0;
 
 5063  for (
int64_t staticTile : op.getStaticInnerTiles()) {
 
 5064    if (ShapedType::isStatic(staticTile))
 
 5067      mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
 
 5069  return mixedInnerTiles;
 
 
 5072template <
typename OpTy>
 
 5074  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
 
 5075                "applies to only pack or unpack operations");
 
 
 5088  size_t dimsPosSize = dimsPos.size();
 
 5089  if (dimsPosSize > rank)
 
 5092  if (dimsPosSize != uniqued.size())
 
 5094  return llvm::any_of(dimsPos, [rank](
int64_t dimPos) {
 
 5095    return dimPos < 0 || dimPos >= 
static_cast<int64_t>(rank);
 
 
 5099template <
typename OpTy>
 
 5101  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
 
 5102                "applies to only pack or unpack operations");
 
 5103  Operation *op = packOrUnPack.getOperation();
 
 5112  if (hasZeros(mixedTiles))
 
 5113    return op->
emitError(
"invalid zero tile factor");
 
 5116  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
 
 5117                                      ? packOrUnPack.getSourceType()
 
 5118                                      : packOrUnPack.getDestType();
 
 5119  size_t unpackedRank = unpackedType.getRank();
 
 5123    return op->
emitError(
"invalid inner_dims_pos vector");
 
 5125    return op->
emitError(
"invalid outer_dims_perm vector");
 
 5126  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
 
 5127    return op->
emitError(
"outer_dims_perm must be a permutation or empty");
 
 5131  if (mixedTiles.size() > unpackedRank) {
 
 5132    return op->
emitError(
"tiling factors must be less than or equal to the " 
 5133                         "input rank for pack or output rank for unpack");
 
 5135  if (mixedTiles.size() != innerDimsPos.size()) {
 
 5137        "tiling factors must equal the number of dimensions to tile");
 
 5140  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
 
 5141                              ? packOrUnPack.getDestType()
 
 5142                              : packOrUnPack.getSourceType();
 
 5143  size_t packedRank = packedType.getRank();
 
 5145  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
 
 5146  if (expectedPackedRank != packedRank) {
 
 5148               "packed rank != (unpacked rank + num tiling factors), got ")
 
 5149           << packedRank << 
" != " << expectedPackedRank;
 
 5155  RankedTensorType expectedPackedType = PackOp::inferPackedType(
 
 5156      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
 
 5158          llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
 
 5160          [](std::tuple<int64_t, OpFoldResult> it) {
 
 5161            int64_t shape = std::get<0>(it);
 
 5162            if (Attribute attr =
 
 5163                    llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
 
 5164              IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
 
 5165              int64_t staticTileSize = intAttr.getValue().getSExtValue();
 
 5166              return shape == staticTileSize;
 
 5168            return ShapedType::isDynamic(
shape);
 
 5170    return op->emitError(
"mismatch in inner tile sizes specified and shaped of " 
 5171                         "tiled dimension in the packed type");
 
 5174                                   packedType.getShape()))) {
 
 5175    return op->emitError(
"expected ")
 
 5176           << expectedPackedType << 
" for the packed domain value, got " 
 
 5189struct PackOrUnPackTransposeResult {
 
 5196template <
typename OpTy>
 
 5197static PackOrUnPackTransposeResult
 
 5201  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
 
 5202                "applies to only pack or unpack operations");
 
 5203  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
 
 5204         "some permutation must be non-empty");
 
 5205  PackOrUnPackTransposeResult metadata;
 
 5206  metadata.innerDimsPos =
 
 5208  metadata.innerTiles =
 
 5210  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
 
 5211                             ? packOrUnPackOp.getSourceRank()
 
 5212                             : packOrUnPackOp.getDestRank();
 
 5213  metadata.outerDimsPerm =
 
 5214      packOrUnPackOp.getOuterDimsPerm().empty()
 
 5215          ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
 
 5217  if (!innerPermutation.empty()) {
 
 5218    assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
 
 5220           "invalid inner permutation");
 
 5224  if (!outerPermutation.empty()) {
 
 5225    assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
 
 5227           "invalid outer permutation");
 
 
 5238  setNameFn(getResult(), 
"pack");
 
 5244                   std::optional<Value> paddingValue,
 
 5246  assert(innerDimsPos.size() == innerTiles.size() &&
 
 5247         "number of tile sizes specified must match the specified number of " 
 5248         "original dimensions to be tiled");
 
 5252  build(builder, state, dest.
getType(), source, dest,
 
 5253        paddingValue ? *paddingValue : 
nullptr,
 
 5254        outerDimsPerm.empty() ? 
nullptr 
 5261PackOp::reifyResultShapes(
OpBuilder &builder,
 
 5279  ShapedType inputType = getSourceType();
 
 5280  int64_t inputRank = inputType.getRank();
 
 5281  return getDestType().getShape().take_front(inputRank);
 
 5285  auto innerDimsPos = getInnerDimsPos();
 
 5292  if (!outerDimPermInv.empty())
 
 5296  for (
auto index : innerDimsPos)
 
 5297    res.push_back(outerDims[
index]);
 
 5308      outputShape.take_front(inputShape.size()));
 
 5309  if (!outerDimsPerm.empty()) {
 
 5310    assert(outerDimsPerm.size() == outputTileSizes.size() &&
 
 5311           "expected output and outer_dims_perm to have same size");
 
 5315  for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
 
 5316    if (ShapedType::isDynamic(inputShape[pos]))
 
 5320    if (!constantTile) {
 
 5321      if (ShapedType::isStatic(outputTileSizes[pos]) &&
 
 5322          (inputShape[pos] % outputTileSizes[pos] != 0))
 
 5324    } 
else if (inputShape[pos] % (*constantTile) != 0) {
 
 5337      outputShape.take_front(inputShape.size()));
 
 5338  if (!outerDimsPerm.empty()) {
 
 5339    assert(outerDimsPerm.size() == outputTileSizes.size() &&
 
 5340           "expected output and outer_dims_perm to have same size");
 
 5344  for (
auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
 
 5345    if (ShapedType::isDynamic(inputShape[pos]) ||
 
 5346        ShapedType::isDynamic(outputTileSizes[pos]))
 
 5351    if (inputShape[pos] % (*constantTile) != 0)
 
 5357LogicalResult PackOp::verify() {
 
 5364  auto paddingValue = getPaddingValue();
 
 5368           << getSourceType().getElementType()
 
 5369           << 
" but got: " << paddingValue.getType();
 
 5372  if (!paddingValue &&
 
 5373      requirePaddingValue(getSourceType().
getShape(), getInnerDimsPos(),
 
 5374                          getDestType().
getShape(), getOuterDimsPerm(),
 
 5377        "invalid tile factor or output size provided. Only full tiles are " 
 5378        "supported when padding_value is not set");
 
 5388  for (
auto o : ofrs) {
 
 5390    if (llvm::dyn_cast_if_present<Value>(o))
 
 5391      result.push_back(ShapedType::kDynamic);
 
 
 5405  for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
 
 5406    if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
 
 5408    if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
 
 5409      resultShape[tiledDim.value()] = ShapedType::kDynamic;
 
 5412    resultShape[tiledDim.value()] = llvm::divideCeilSigned(
 
 5413        resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
 
 5417  if (!outerDimsPerm.empty())
 
 5421  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
 
 
 5434  for (
auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
 
 5436        builder, loc, ceilDivExpr,
 
 5437        {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
 
 5439  if (!outerDimsPerm.empty())
 
 5441  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
 
 5446                               innerDimsPos, outerDimsPerm);
 
 5452  for (
unsigned i = 0; i < resultDims.size(); ++i) {
 
 5453    if (ShapedType::isStatic(resultTypeShape[i]))
 
 5464RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
 
 5469      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
 
 5470  return RankedTensorType::get(resultShape, sourceType.getElementType());
 
 5485  for (
auto [
index, value] : llvm::enumerate(
 
 5486           llvm::cast<RankedTensorType>(source.
getType()).getShape())) {
 
 5487    if (ShapedType::isDynamic(value))
 
 5488      mixedSizes.push_back(
 
 5489          tensor::DimOp::create(
b, loc, source, 
index).getResult());
 
 5491      mixedSizes.push_back(
b.getIndexAttr(value));
 
 5493  for (
auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
 
 5494    int64_t dimPos = std::get<0>(it);
 
 5496    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
 
 5498  if (!outerDimsPerm.empty())
 
 5501  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
 
 5502  auto elemType = llvm::cast<ShapedType>(source.
getType()).getElementType();
 
 5503  return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
 
 5510      *
this, innerPermutation, outerPermutation);
 
 5511  Value transposedDest =
 
 5512      createDestinationTensor(
b, loc, getSource(), metadata.innerTiles,
 
 5513                              metadata.innerDimsPos, metadata.outerDimsPerm);
 
 5514  return PackOp::create(
b, loc, getSource(), transposedDest,
 
 5515                        metadata.innerDimsPos, metadata.innerTiles,
 
 5516                        getPaddingValue(), metadata.outerDimsPerm);
 
 5520template <
typename OpTy>
 
 5522  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
 
 5523                "applies to only pack or unpack operations");
 
 5524  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
 
 5526                              : op.getSourceType();
 
 5528  for (
auto [dimDest, 
tile] : llvm::zip(
 
 5529           packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
 
 5531    if (!constTileSize || ShapedType::isDynamic(dimDest))
 
 
 5538  if (getPaddingValue())
 
 5553  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
 
 5555  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
 
 
 5567  auto packTiles = packOp.getMixedTiles();
 
 5568  auto unPackTiles = unPackOp.getMixedTiles();
 
 5569  if (packTiles.size() != unPackTiles.size())
 
 5571  for (
size_t i = 0, e = packTiles.size(); i < e; i++) {
 
 
 5580  auto srcType = op.getSourceType();
 
 5581  if (llvm::any_of(op.getInnerDimsPos(),
 
 5582                   [&](
int64_t pos) { return srcType.isDynamicDim(pos); }))
 
 5584  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
 
 5586  return !PackOp::requirePaddingValue(
 
 5587      srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
 
 5588      op.getOuterDimsPerm(), op.getMixedTiles());
 
 
 5595  bool changeNeeded = 
false;
 
 5596  srcShape.assign(packOp.getSourceType().getShape().begin(),
 
 5597                  packOp.getSourceType().getShape().end());
 
 5598  destShape.assign(packOp.getDestType().getShape().begin(),
 
 5599                   packOp.getDestType().getShape().end());
 
 5600  llvm::SmallSetVector<int64_t, 4> innerDims;
 
 5601  innerDims.insert_range(packOp.getInnerDimsPos());
 
 5603  if (!packOp.getOuterDimsPerm().empty())
 
 5605  int srcRank = packOp.getSourceRank();
 
 5606  for (
auto i : llvm::seq<int64_t>(0, srcRank)) {
 
 5607    if (innerDims.contains(i))
 
 5611    if (!inverseOuterDimsPerm.empty())
 
 5612      destPos = inverseOuterDimsPerm[srcPos];
 
 5613    if (ShapedType::isDynamic(srcShape[srcPos]) ==
 
 5614        ShapedType::isDynamic(destShape[destPos])) {
 
 5617    int64_t size = srcShape[srcPos];
 
 5618    if (ShapedType::isDynamic(size))
 
 5619      size = destShape[destPos];
 
 5620    srcShape[srcPos] = size;
 
 5621    destShape[destPos] = size;
 
 5622    changeNeeded = 
true;
 
 5624  return changeNeeded;
 
 
 5627LogicalResult PackOp::canonicalize(PackOp packOp, 
PatternRewriter &rewriter) {
 
 5629  if (
auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
 
 5630    if (unPackOp.getSourceType() == packOp.getDestType() &&
 
 5631        !packOp.getPaddingValue() &&
 
 5634      rewriter.
replaceOp(packOp, unPackOp.getSource());
 
 5642    packOp.getPaddingValueMutable().clear();
 
 5651    Value source = packOp.getSource();
 
 5652    if (srcShape != packOp.getSourceType().getShape()) {
 
 5653      auto newSrcType = packOp.getSourceType().clone(srcShape);
 
 5655          tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
 
 5657    Value dest = packOp.getDest();
 
 5658    RankedTensorType originalResultType = packOp.getDestType();
 
 5659    bool needUpdateDestType = (destShape != originalResultType.getShape());
 
 5660    if (needUpdateDestType) {
 
 5661      auto newDestType = packOp.getDestType().clone(destShape);
 
 5663          tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
 
 5666      packOp.getSourceMutable().assign(source);
 
 5667      packOp.getDestMutable().assign(dest);
 
 5668      packOp.getResult().setType(cast<RankedTensorType>(dest.
getType()));
 
 5671    if (needUpdateDestType) {
 
 5674          tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
 
 5683template <
typename PackOrUnpackOp>
 
 5685                           RankedTensorType packedTensorType) {
 
 5686  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
 
 5687                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
 
 5688                "Function meant for pack/unpack");
 
 5693  int64_t numPackedDims = innerDimsPos.size();
 
 5694  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
 
 5695  if (orderedDims != innerDimsPos) {
 
 5701  int64_t packedRank = packedTensorType.getRank();
 
 5711  return llvm::all_of(
 
 5712      llvm::seq<int64_t>(0, packedRank - numPackedDims),
 
 5713      [&packedShape](
int64_t i) { 
return packedShape[i] == 1; });
 
 
 5716bool PackOp::isLikePad() {
 
 5717  auto packedTensorType =
 
 5718      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
 
 5723  std::optional<Attribute> paddingValue;
 
 5724  if (
auto pad = adaptor.getPaddingValue())
 
 5726  if (
OpFoldResult reshapedSource = reshapeConstantSource(
 
 5727          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
 
 5728          getDestType(), paddingValue))
 
 5729    return reshapedSource;
 
 5768        PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
 
 5769                       op.getInnerDimsPos(), newMixedTileSizes,
 
 5770                       op.getPaddingValue(), op.getOuterDimsPerm());
 
 5771    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
 5774    Value oldResult = op.getResult();
 
 5775    Value newResult = newOp.getResult();
 
 5778            ? tensor::CastOp::create(rewriter, op->getLoc(),
 
 5779                                     oldResult.
getType(), newResult)
 
 
 
 5792void UnPackOp::getAsmResultNames(
 
 5794  setNameFn(getResult(), 
"unpack");
 
 5798UnPackOp::reifyResultShapes(
OpBuilder &builder,
 
 5816  ShapedType destType = getDestType();
 
 5817  int64_t destRank = destType.getRank();
 
 5818  return getSourceType().getShape().take_front(destRank);
 
 5822  auto innerDimsPos = getInnerDimsPos();
 
 5829  if (!outerDimPermInv.empty())
 
 5833  for (
auto index : innerDimsPos)
 
 5834    res.push_back(outerDims[
index]);
 
 5839LogicalResult UnPackOp::verify() {
 
 5855  assert(innerDimsPos.size() == innerTiles.size() &&
 
 5856         "number of tile sizes specified must match the specified number of " 
 5857         "original dimensions to be tiled");
 
 5861  build(builder, state, dest.
getType(), source, dest,
 
 5862        outerDimsPerm.empty() ? 
nullptr 
 5880  auto srcType = llvm::cast<RankedTensorType>(source.
getType());
 
 5882       llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
 
 5883    if (srcType.isDynamicDim(i))
 
 5884      mixedSizes.push_back(
 
 5885          tensor::DimOp::create(
b, loc, source, i).getResult());
 
 5887      mixedSizes.push_back(
b.getIndexAttr(srcType.getDimSize(i)));
 
 5889  if (!outerDimsPerm.empty()) {
 
 5894  for (
auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
 
 5895    mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
 
 5897  auto elemType = srcType.getElementType();
 
 5898  return tensor::EmptyOp::create(
b, loc, mixedSizes, elemType);
 
 5902                                         Value transposedSource,
 
 5906      *
this, innerPermutation, outerPermutation);
 
 5907  return UnPackOp::create(
b, loc, transposedSource, getDest(),
 
 5908                          metadata.innerDimsPos, metadata.innerTiles,
 
 5909                          metadata.outerDimsPerm);
 
 5916  bool changeNeeded = 
false;
 
 5917  srcShape.assign(op.getSourceType().getShape().begin(),
 
 5918                  op.getSourceType().getShape().end());
 
 5919  destShape.assign(op.getDestType().getShape().begin(),
 
 5920                   op.getDestType().getShape().end());
 
 5921  llvm::SmallSetVector<int64_t, 4> innerDims;
 
 5922  innerDims.insert_range(op.getInnerDimsPos());
 
 5924  if (!op.getOuterDimsPerm().empty())
 
 5926  int destRank = op.getDestRank();
 
 5927  for (
auto i : llvm::seq<int64_t>(0, destRank)) {
 
 5928    if (innerDims.contains(i))
 
 5932    if (!inverseOuterDimsPerm.empty())
 
 5933      srcPos = inverseOuterDimsPerm[destPos];
 
 5934    if (ShapedType::isDynamic(srcShape[srcPos]) ==
 
 5935        ShapedType::isDynamic(destShape[destPos])) {
 
 5938    int64_t size = srcShape[srcPos];
 
 5939    if (ShapedType::isDynamic(size))
 
 5940      size = destShape[destPos];
 
 5941    srcShape[srcPos] = size;
 
 5942    destShape[destPos] = size;
 
 5943    changeNeeded = 
true;
 
 5945  return changeNeeded;
 
 
 5948LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
 
 5951  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
 
 5952    if (packOp.getSourceType() != unPackOp.getDestType())
 
 5954    if (packOp.getPaddingValue() ||
 
 5958    rewriter.
replaceOp(unPackOp, packOp.getSource());
 
 5962  if (
auto dstStyleOp =
 
 5963          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
 
 5964    auto destValue = cast<OpResult>(unPackOp.getDest());
 
 5965    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
 
 5967                             [&]() { unPackOp.setDpsInitOperand(0, newDest); });
 
 5971  if (unPackOp->hasOneUse()) {
 
 5972    auto extractSliceUser =
 
 5973        dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
 
 5974    if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
 
 5977      auto newDest = tensor::ExtractSliceOp::create(
 
 5978          rewriter, unPackOp->getLoc(), unPackOp.getDest(),
 
 5979          extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
 
 5980          extractSliceUser.getMixedStrides());
 
 5982        unPackOp.setDpsInitOperand(0, newDest);
 
 5983        unPackOp.getResult().setType(newDest.
getType());
 
 5985      rewriter.
replaceOp(extractSliceUser, unPackOp);
 
 5994    Value source = unPackOp.getSource();
 
 5995    if (srcShape != unPackOp.getSourceType().getShape()) {
 
 5996      auto newSrcType = unPackOp.getSourceType().clone(srcShape);
 
 5997      source = tensor::CastOp::create(rewriter, loc, newSrcType,
 
 5998                                      unPackOp.getSource());
 
 6000    Value dest = unPackOp.getDest();
 
 6001    if (destShape != unPackOp.getDestType().getShape()) {
 
 6002      auto newDestType = unPackOp.getDestType().clone(destShape);
 
 6003      dest = tensor::CastOp::create(rewriter, loc, newDestType,
 
 6004                                    unPackOp.getDest());
 
 6006    Value newOp = UnPackOp::create(
 
 6007        rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
 
 6008        unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
 
 6010        unPackOp, unPackOp.getResult().
getType(), newOp);
 
 6017bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
 
 6019  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
 
 6024  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
 
 6027  for (
auto [pos, tileSize] :
 
 6028       llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
 
 6029    if (unpackedTypeAfterFold.isDynamicDim(pos))
 
 6031    if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
 
 6033    if (ShapedType::isDynamic(tileSize))
 
 6035    int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
 
 6036                          unpackedTypeAfterFold.getDimSize(pos);
 
 6037    if (paddingSize >= tileSize)
 
 6043bool UnPackOp::isLikeUnPad() {
 
 6044  RankedTensorType packedTensorType = getSourceType();
 
 6049  if (
OpFoldResult reshapedSource = reshapeConstantSource(
 
 6050          llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
 
 6052    return reshapedSource;
 
 6081    Value sourceTensor = newOperands[0];
 
 6085        rewriter, sourceTensor.
getType(), op.getMixedTiles());
 
 6091    UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
 
 6092                                      newOperands[1], op.getInnerDimsPos(),
 
 6093                                      newMixedTileSizes, op.getOuterDimsPerm());
 
 6094    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
 
 6097    Value oldResult = op.getResult();
 
 6098    Value newResult = newOp.getResult();
 
 6101            ? tensor::CastOp::create(rewriter, op->getLoc(),
 
 6102                                     oldResult.
getType(), newResult)
 
 
 
 6116      utils::IteratorType::reduction, utils::IteratorType::parallel,
 
 6117      utils::IteratorType::parallel, utils::IteratorType::reduction};
 
 6121BatchReduceMatmulOp::getDefaultIndexingMaps(
MLIRContext *context) {
 
 6125  indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d1, d3}, context));
 
 6126  indexingMaps.push_back(
AffineMap::get(4, 0, {d0, d3, d2}, context));
 
 6128  return indexingMaps;
 
 6131bool BatchReduceMatmulOp::isDefaultIndexingMaps(
Attribute attr) {
 
 6132  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
 
 6135  if (maps.size() != 3)
 
 6144unsigned BatchReduceMatmulOp::getNumRegionArgs() { 
return 3; }
 
 6146std::string BatchReduceMatmulOp::getLibraryCallName() {
 
 6152bool BatchReduceMatmulOp::hasUserDefinedMaps() {
 
 6156  return defaultMaps != explicitMaps;
 
 6166bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(
AffineMap bcastMap,
 
 6169         "Expected less than 3 result dim expr.");
 
 6170  bool isValid = 
false;
 
 6171  enum Indices { batchPos, mPos, nPos, kPos };
 
 6182              : ((expr0.isFunctionOfDim(batchPos) &&
 
 6183                  expr1.isFunctionOfDim(kPos)) ||
 
 6184                 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
 
 6189void BatchReduceMatmulOp::regionBuilder(
 
 6193    emitError() << 
"BatchReduceMatmulOp regionBuilder expects 3 args, got " 
 6198         "BatchReduceMatmulOp regionBuilder expects 3 args");
 
 6199  RegionBuilderHelper helper(
b, block);
 
 6204      helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(0));
 
 6206      helper.buildTypeFn(TypeFn::cast_signed, toType, block.
getArgument(1));
 
 6207  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
 
 6209      helper.buildBinaryFn(BinaryFn::add, block.
getArgument(2), mulVal);
 
 6210  yields.push_back(addVal);
 
 6211  helper.yieldOutputs(yields);
 
 6214ParseResult BatchReduceMatmulOp::parse(
OpAsmParser &parser,
 
 6227      if (!isa<AffineMapAttr>(mapAttr)) {
 
 6229                                "expected affine map attribute");
 
 6231      indexingMapsAttr.push_back(mapAttr);
 
 6241  if (indexingMapsAttr.empty()) {
 
 6242    indexingMapsAttr = llvm::map_to_vector(
 
 6243        BatchReduceMatmulOp::getDefaultIndexingMaps(parser.
getContext()),
 
 6246  result.addAttribute(
"indexing_maps",
 
 6248  return ::parseNamedStructuredOp(parser, 
result,
 
 6249                                  BatchReduceMatmulOp::getNumRegionArgs(),
 
 6250                                  BatchReduceMatmulOp::getRegionBuilder());
 
 6255      BatchReduceMatmulOp::getDefaultIndexingMaps(
getContext()),
 
 6258  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
 
 6259    p << 
" indexing_maps = [";
 
 6260    llvm::interleaveComma(getIndexingMaps(), p,
 
 6266      "operandSegmentSizes", 
"linalg.memoized_indexing_maps", 
"indexing_maps"};
 
 6272LogicalResult BatchReduceMatmulOp::verify() {
 
 6275  if (!hasUserDefinedMaps())
 
 6278  for (
unsigned opIndex = 0; opIndex < 3; opIndex++) {
 
 6284LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
 
 6288void BatchReduceMatmulOp::getEffects(
 
 6291  if (hasPureTensorSemantics())
 
 6307void LinalgDialect::getCanonicalizationPatterns(
 
 6316  return arith::ConstantOp::materialize(builder, value, type, loc);
 
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
 
static Type getElementType(Type type)
Determine the element type of type.
 
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the MatmulO...
 
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, function_ref< InFlightDiagnostic()> emitError, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
 
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
 
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
 
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
 
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
 
static bool canUseShortForm(Block *body, bool initFirst=false, bool mapInit=true)
 
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
 
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
 
llvm::function_ref< void( ImplicitLocOpBuilder &, Block &, ArrayRef< NamedAttribute >, function_ref< InFlightDiagnostic()>)> RegionBuilderFn
 
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
 
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
 
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
 
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
 
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
 
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp/BatchReduceMatmulOp has...
 
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
 
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
 
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
 
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
 
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
 
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
 
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
 
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
 
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, LinalgOp linalgOp)
 
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
 
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
 
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
 
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false, bool mapInit=true)
 
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
 
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
 
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
 
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
 
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
 
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
 
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
 
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
 
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
 
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > givenTileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
 
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
 
Base type for affine expression.
 
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
 
AffineExpr ceilDiv(uint64_t v) const
 
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
 
AffineMap dropResults(ArrayRef< int64_t > positions) const
 
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
 
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
 
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
 
unsigned getNumDims() const
 
ArrayRef< AffineExpr > getResults() const
 
unsigned getNumResults() const
 
AffineExpr getResult(unsigned idx) const
 
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
 
@ Paren
Parens surrounding zero or more operands.
 
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
 
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
 
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
 
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
 
MLIRContext * getContext() const
 
virtual ParseResult parseRParen()=0
Parse a ) token.
 
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
 
virtual ParseResult parseLSquare()=0
Parse a [ token.
 
virtual ParseResult parseRSquare()=0
Parse a ] token.
 
virtual ParseResult parseRBrace()=0
Parse a } token.
 
virtual ParseResult parseEqual()=0
Parse a = token.
 
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
 
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
 
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
 
virtual ParseResult parseGreater()=0
Parse a '>' token.
 
virtual ParseResult parseLParen()=0
Parse a ( token.
 
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
 
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
 
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
 
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
 
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
 
virtual void printAttribute(Attribute attr)
 
Attributes are known-constant values of operations.
 
Block represents an ordered list of Operations.
 
BlockArgument getArgument(unsigned i)
 
unsigned getNumArguments()
 
OpListType & getOperations()
 
Operation * getTerminator()
Get the terminator operation of this block.
 
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
 
BlockArgListType getArguments()
 
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
 
This class is a general helper class for creating context-global objects like types,...
 
IntegerAttr getIndexAttr(int64_t value)
 
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
 
IntegerAttr getIntegerAttr(Type type, int64_t value)
 
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
 
AffineMap getMultiDimIdentityMap(unsigned rank)
 
IntegerAttr getI64IntegerAttr(int64_t value)
 
StringAttr getStringAttr(const Twine &bytes)
 
AffineExpr getAffineDimExpr(unsigned position)
 
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
 
MLIRContext * getContext() const
 
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
 
IRValueT get() const
Return the current value being used by this operand.
 
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
 
This class represents a diagnostic that is inflight and set to be reported.
 
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
 
MLIRContext is the top-level object for a collection of MLIR operations.
 
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
 
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
 
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
 
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
 
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
 
NamedAttribute represents a combination of a name and an Attribute value.
 
StringAttr getName() const
Return the name of the attribute.
 
Attribute getValue() const
Return the value of the attribute.
 
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
 
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
 
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
 
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
 
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
 
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
 
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
 
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
 
virtual void increaseIndent()=0
Increase indentation.
 
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
 
virtual void decreaseIndent()=0
Decrease indentation.
 
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
 
RAII guard to reset the insertion point of the builder when destroyed.
 
This class helps build Operations.
 
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
 
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
 
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
 
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
 
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
 
This class represents a single result from folding an operation.
 
This class represents an operand of an operation.
 
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
 
unsigned getResultNumber() const
Returns the number of this result.
 
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
 
Operation is the basic unit of execution within MLIR.
 
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
 
result_iterator result_begin()
 
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
 
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
 
Location getLoc()
The source location the operation was defined or derived from.
 
unsigned getNumOperands()
 
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
 
OperationName getName()
The name of an operation is the key identifier for it.
 
operand_type_range getOperandTypes()
 
result_iterator result_end()
 
result_type_range getResultTypes()
 
operand_range getOperands()
Returns an iterator on the underlying Values.
 
result_range getResults()
 
unsigned getNumResults()
Return the number of results held by this operation.
 
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
 
This class contains a list of basic blocks and a link to the parent operation it is attached to.
 
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
 
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
 
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
 
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
 
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
 
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
 
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
 
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
 
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
 
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
 
This class represents a specific instance of an effect.
 
static DerivedEffect * get()
 
static DefaultResource * get()
 
This class provides an abstraction over the various different ranges of value types.
 
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
 
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
 
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
 
This class provides an abstraction over the different types of ranges over Values.
 
type_range getTypes() const
 
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
 
Type getType() const
Return the type of this value.
 
Block * getParentBlock()
Return the Block in which this Value is defined.
 
bool hasOneUse() const
Returns true if this value has exactly one use.
 
Location getLoc() const
Return the location of this value.
 
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
 
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
 
static Attribute parse(AsmParser &parser, Type type)
 
Specialization of linalg.batch_matmul op that has a transpose map on A.
 
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
 
static bool classof(Operation *op)
 
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
 
static BatchMatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
 
Specialization of linalg.batch_matmul op that has a transpose map on B.
 
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
 
static bool classof(Operation *op)
 
static BatchMatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
 
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
 
Specialization of linalg.matmul op that has a transpose map on A.
 
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
 
static MatmulTransposeAOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
 
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose A matmul.
 
static bool classof(Operation *op)
 
Specialization of linalg.matmul op that has a transpose map on B.
 
static void build(OpBuilder &builder, OperationState &result, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
Build a transpose B matmul.
 
static MatmulTransposeBOp create(OpBuilder &builder, Location location, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes={})
 
static bool isDefaultIndexingMaps(Attribute attr)
Checks if the affine map is the expected one for this operation.
 
static bool classof(Operation *op)
 
constexpr auto RecursivelySpeculatable
 
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
 
constexpr auto Speculatable
 
constexpr auto NotSpeculatable
 
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
 
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
 
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
 
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Values to kDynamic,...
 
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
 
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
 
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
 
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
 
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
 
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
 
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
 
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
 
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
 
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
 
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
 
static bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
 
std::string generateLibraryCallName(Operation *op)
Returns the name-mangled library call name to disambiguate between different overloads at the C level...
 
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
 
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
 
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
 
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
 
static FailureOr< SmallVector< SmallVector< int64_t > > > getAffineResultPositions(ArrayAttr maps)
 
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
 
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
 
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
 
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
 
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
 
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
 
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
 
static FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
 
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
 
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
 
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
 
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
 
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
 
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
 
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
 
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
 
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
 
Include the generated interface declarations.
 
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
 
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
 
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
 
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
 
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
 
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
 
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
 
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
 
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
 
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
 
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
 
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
 
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
 
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
 
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
 
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
 
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
 
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
 
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. sizeof...(exprs)].
 
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
 
llvm::TypeSwitch< T, ResultT > TypeSwitch
 
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
 
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
 
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
 
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
 
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
 
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
 
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
 
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
 
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
 
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
 
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
 
llvm::function_ref< Fn > function_ref
 
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
 
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
 
Fold back-to-back broadcasts together.
 
LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, PatternRewriter &rewriter) const override
 
Fold transpose with transpose.
 
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
 
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broad...
 
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
 
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
 
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
 
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
 
This represents an operation in an abstracted form, suitable for use with the builder APIs.
 
void addOperands(ValueRange newOperands)
 
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
 
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
 
void addTypes(ArrayRef< Type > newTypes)
 
Region * addRegion()
Create a region that should be attached to the operation.
 
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has a source that is more static t...
 
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
 
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has a source that is more static...
 
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override