//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return tensor::DimOp::create(builder, loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return memref::DimOp::create(builder, loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of
/// the `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
                                              strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
                                         strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(
    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
    function_ref<InFlightDiagnostic()>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
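/// Illustratively, for an op with one tensor<?xf32> input and one
/// tensor<?xf32> output, the created block has signature (f32, f32), and a
/// hypothetical `regionBuilder` would emit the scalar payload, e.g.:
///   [](ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute>,
///      function_ref<InFlightDiagnostic()>) {
///     Value sum = b.create<arith::AddFOp>(block.getArgument(0),
///                                         block.getArgument(1));
///     b.create<linalg::YieldOp>(sum);
///   }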
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   function_ref<InFlightDiagnostic()> emitError,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs, emitError);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
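/// For example, if `outputs` is {tensor<4xf32>, memref<4xf32>} and
/// `resultTensorTypes` is not provided, the derived result types are just
/// {tensor<4xf32>}: only ranked tensor outputs become results, since memref
/// outputs are mutated in place.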
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), /*emitError=*/{},
                         regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
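/// Illustratively, this parses the `ins`/`outs` clauses of an op such as:
///   linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
///                 outs(%c : tensor<4x16xf32>)
/// (result types and the optional region are handled by the callers).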
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder, SMLoc loc) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  ParseResult result = success();
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes, attrs,
      [&]() {
        result = failure();
        return parser.emitError(loc);
      },
      regionBuilder);
  return result;
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  SMLoc loc = parser.getCurrentLocation();
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder, loc))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. Helpers build the unary, binary, and type conversion functions defined
// by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that
// uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
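//
// For example, buildBinaryFn(BinaryFn::add, <i8 value>, <i32 value>) is
// rejected rather than implicitly extending the i8 operand to i32.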
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return math::ExpOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::log:
      return math::LogOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::abs:
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::ceil:
      return math::CeilOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::floor:
      return math::FloorOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::negf:
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return math::RoundOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return math::TanhOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::erf:
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  // If emitError is provided, an error is emitted if the operation is not
  // supported and nullptr is returned; otherwise an assertion is raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

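/// Erase copies whose input and output are the same value, e.g.
/// (illustrative):
///   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
/// In the tensor case the copy's results are replaced by its inputs instead.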
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold a linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
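/// Illustratively:
///   %f = linalg.fill ins(%cst : f32)
///          outs(%init : tensor<6x4xf32>) -> tensor<6x4xf32>
///   %0 = tensor.collapse_shape %f [[0, 1]]
///          : tensor<6x4xf32> into tensor<24xf32>
/// becomes a collapse_shape of %init followed by a fill of the collapsed
/// tensor.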
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
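/// Illustratively:
///   %f = linalg.fill ins(%cst : f32)
///          outs(%init : tensor<8xf32>) -> tensor<8xf32>
///   %p = tensor.pad %f low[2] high[2] {
///     ^bb0(%i: index):
///       tensor.yield %cst : f32
///   } : tensor<8xf32> to tensor<12xf32>
/// becomes a fill of a fresh tensor.empty of the padded shape.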
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                                padOp.getResultType().getElementType());
    Value replacement =
        FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
                       ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                           padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
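/// The pattern shifts the insert offsets by the pad's low padding so that
/// <input> lands at the same absolute position; the padded border already
/// holds the fill value, so dropping the pad changes no element.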
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
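/// Every element of the filled tensor is the fill's scalar input, so an
/// extract at any index, e.g. tensor.extract %filled[%i, %j], can be replaced
/// by that scalar directly.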
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
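/// In both cases every element of the packed result equals the fill value,
/// so filling the pack's destination directly is equivalent, e.g.
/// (illustrative): fill(%cst, %pack_dest) replaces pack(fill(%cst, %src)).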
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
                                packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
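/// copy(fill(<v>)) is rewritten to fill the copy's destination directly, and
/// a copy whose destination is a fill result copies into the fill's init
/// instead, since the filled contents would be overwritten anyway.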
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
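/// transpose(fill(<v>)) yields the same elements as filling the transpose
/// init directly, so the transpose is dropped and the fill is retargeted.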
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
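/// Illustratively, concat(dim = 0, fill(%cst, %a), fill(%cst, %b)) becomes
/// fill(%cst, concat(dim = 0, %a, %b)).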
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed because tests still use the old format when the
      // 'iterator_types' attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits into a dictAttr. The name is unimportant, as
  // we will overwrite result.attributes. The core linalg traits must contain
  // the information necessary to pass the verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed because tests still use the old format when the 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
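/// Illustratively, a linalg.generic with identity indexing maps whose body is
/// just `linalg.yield %in` returns its input unchanged, so its results can be
/// replaced by the corresponding operands (with a tensor.cast or
/// sparse_tensor.convert inserted when the types differ).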
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion and dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1500 
1501 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1502  const OperationName &payloadOpName,
1503  const NamedAttrList &payloadOpAttrs,
1504  ArrayRef<Value> operands,
1505  bool initFirst = false) {
1506  OpBuilder b(parser.getContext());
1507  Region *body = result.addRegion();
1508  Block &block = body->emplaceBlock();
1509  b.setInsertionPointToStart(&block);
1510  for (auto &operand : operands) {
1511  block.addArgument(
1512  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1513  b.getUnknownLoc());
1514  }
1515  SmallVector<Value> payloadOpOperands;
1516  // If the initFirst flag is enabled, the init is placed in the first
1517  // position of the payload operands.
1518  if (initFirst) {
1519  payloadOpOperands.push_back(block.getArguments().back());
1520  for (const auto &arg : block.getArguments().drop_back())
1521  payloadOpOperands.push_back(arg);
1522  } else {
1523  payloadOpOperands = {block.getArguments().begin(),
1524  block.getArguments().end()};
1525  }
1526 
1527  Operation *payloadOp = b.create(
1528  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1529  payloadOpOperands,
1530  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1531  .getElementType()},
1532  payloadOpAttrs);
1533  YieldOp::create(b, result.location, payloadOp->getResults());
1534 }
1535 
1536 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1537  std::optional<OperationName> payloadOpName;
1538  NamedAttrList payloadOpAttrs;
1539  if (succeeded(parser.parseOptionalLBrace())) {
1540  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1541  if (failed(operationName))
1542  return failure();
1543  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1544  return failure();
1545  payloadOpName = operationName.value();
1546  if (parser.parseRBrace())
1547  return failure();
1548  }
1549 
1550  if (parseDstStyleOp(parser, result))
1551  return failure();
1552 
1553  if (payloadOpName.has_value()) {
1554  if (!result.operands.empty())
1555  addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1556  payloadOpAttrs,
1557  ArrayRef(result.operands).drop_back());
1558  else
1559  result.addRegion();
1560  } else {
1561  SmallVector<OpAsmParser::Argument> regionArgs;
1562  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1563  /*allowType=*/true, /*allowAttrs=*/true)) {
1564  return failure();
1565  }
1566  Region *body = result.addRegion();
1567  if (parser.parseRegion(*body, regionArgs))
1568  return failure();
1569  }
1570  return success();
1571 }
1572 
1573 static bool canUseShortForm(Block *body, bool initFirst = false) {
1574  // Check if the body can be printed in short form. The following 4 conditions
1575  // must be satisfied:
1576 
1577  // 1) The body must contain exactly 2 operations: the payload op and a yield.
1578  if (body->getOperations().size() != 2)
1579  return false;
1580  Operation &payload = body->getOperations().front();
1581 
1582  // 2) The payload op must have the same number of operands as the number of
1583  // block arguments.
1584  if (payload.getNumOperands() == 0 ||
1585  payload.getNumOperands() != body->getNumArguments())
1586  return false;
1587 
1588  // 3) If `initFirst` is true (e.g., for reduction ops), the init block
1589  // argument must be the first operand of the payload op; otherwise, the
1590  // operands must match the block arguments in order.
1591  if (initFirst) {
1592  // check init
1593  if (payload.getOperands().back() != body->getArgument(0))
1594  return false;
1595  // check rest
1596  for (const auto &[operand, bbArg] :
1597  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1598  if (bbArg != operand)
1599  return false;
1600  }
1601  } else {
1602  for (const auto &[operand, bbArg] :
1603  llvm::zip(payload.getOperands(), body->getArguments())) {
1604  if (bbArg != operand)
1605  return false;
1606  }
1607  }
1608 
1609  // 4) The `yield` operand must be the result of the payload op.
1610  auto yieldOp = cast<YieldOp>(body->getTerminator());
1611  return yieldOp.getNumOperands() == 1 &&
1612  yieldOp.getOperand(0).getDefiningOp() &&
1613  yieldOp.getOperand(0).getDefiningOp() == &payload;
1614 }
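// For illustration, a body consisting of a single payload op that reads the
// block arguments in order prints in the short form (a sketch; the types
// are made up):
//
//   %mapped = linalg.map { arith.addf }
//               ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//               outs(%init : tensor<8xf32>)
//
// Any body violating one of the conditions above keeps the long form with
// an explicit region.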
1615 
1616 static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1617  SmallVector<StringRef> elidedAttrs;
1618  std::string attrToElide;
1619  p << " { " << payloadOp->getName().getStringRef();
1620  for (const auto &attr : payloadOp->getAttrs()) {
1621  auto fastAttr =
1622  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1623  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1624  attrToElide = attr.getName().str();
1625  elidedAttrs.push_back(attrToElide);
1626  break;
1627  }
1628  }
1629  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1630  p << " }";
1631 }
1632 
1633 void MapOp::print(OpAsmPrinter &p) {
1634  Block *mapper = getBody();
1635  bool useShortForm = canUseShortForm(mapper);
1636  if (useShortForm) {
1637  printShortForm(p, &mapper->getOperations().front());
1638  }
1639 
1640  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1641  p.printOptionalAttrDict((*this)->getAttrs());
1642 
1643  if (!useShortForm) {
1644  // Print region if the payload op was not detected.
1645  p.increaseIndent();
1646  p.printNewline();
1647  p << "(";
1648  llvm::interleaveComma(mapper->getArguments(), p,
1649  [&](auto arg) { p.printRegionArgument(arg); });
1650  p << ") ";
1651 
1652  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1653  p.decreaseIndent();
1654  }
1655 }
1656 
1657 LogicalResult MapOp::verify() {
1658  auto *bodyBlock = getBody();
1659  auto blockArgs = bodyBlock->getArguments();
1660 
1661  // Check that the number of `inputs` matches the arity of the `mapper` region.
1662  if (getInputs().size() != blockArgs.size())
1663  return emitOpError() << "expects number of operands to match the arity of "
1664  "mapper, but got: "
1665  << getInputs().size() << " and " << blockArgs.size();
1666 
1667  // The parameters of the mapper must all match the element types of the inputs.
1668  for (const auto &[bbArgType, inputArg] :
1669  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1670  auto inputElemType =
1671  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1672  if (bbArgType != inputElemType) {
1673  return emitOpError() << "expected element type of input " << inputElemType
1674  << " to match bbArg type " << bbArgType;
1675  }
1676  }
1677 
1678  // The shape of each input must match the shape of the output.
1679  auto outputShape = getInit().getType().getShape();
1680  for (Type inputArgType : TypeRange{getInputs()}) {
1681  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1682  if (inputElemShape != outputShape) {
1683  return emitOpError() << "expected shape of input (" << inputElemShape
1684  << ") to match shape of output (" << outputShape
1685  << ")";
1686  }
1687  }
1688 
1689  return success();
1690 }
1691 
1692 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1693  int64_t rank = getInit().getType().getRank();
1694  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1695 }
1696 
1697 ArrayAttr MapOp::getIndexingMaps() {
1698  Builder builder(getContext());
1699  int64_t rank = getInit().getType().getRank();
1700  int64_t numIndexingMaps = getOperands().size();
1701  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1702  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1703 }
1704 
1705 void MapOp::getEffects(
1706  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1707  &effects) {
1708  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1709 }
1710 
1711 Speculation::Speculatability MapOp::getSpeculatability() {
1712  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1713 }
1714 
1715 //===----------------------------------------------------------------------===//
1716 // ReduceOp
1717 //===----------------------------------------------------------------------===//
1718 
1719 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1720  OpAsmSetValueNameFn setNameFn) {
1721  for (Value v : getRegionInputArgs())
1722  setNameFn(v, "in");
1723  for (Value v : getRegionOutputArgs())
1724  setNameFn(v, "init");
1725 }
1726 
1727 void ReduceOp::getAsmResultNames(
1728  function_ref<void(Value, StringRef)> setNameFn) {
1729  if (!getResults().empty())
1730  setNameFn(getResults().front(), "reduced");
1731 }
1732 
1733 void ReduceOp::build(
1734  OpBuilder &builder, OperationState &result, ValueRange inputs,
1735  ValueRange inits, ArrayRef<int64_t> dimensions,
1736  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1737  ArrayRef<NamedAttribute> attributes) {
1738  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1739  result.addAttributes(attributes);
1740 
1741  // Add output types for `RankedTensorType` output arguments.
1742  for (Value init : inits) {
1743  Type initType = init.getType();
1744  if (llvm::isa<RankedTensorType>(initType))
1745  result.addTypes(initType);
1746  }
1747 
1748  if (bodyBuild)
1749  buildGenericRegion(builder, result.location, *result.regions.front(),
1750  inputs, inits, bodyBuild);
1751 }
1752 
1753 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1754  int64_t inputRank =
1755  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1756  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1757  utils::IteratorType::parallel);
1758  for (int64_t reductionDim : getDimensions())
1759  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1760  return iteratorTypes;
1761 }
1762 
1763 ArrayAttr ReduceOp::getIndexingMaps() {
1764  int64_t inputRank =
1765  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1766  SmallVector<AffineMap> affineMaps(
1767  getNumDpsInputs(),
1768  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1769  AffineMap resultMap =
1770  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1771  .dropResults(getDimensions());
1772  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1773  affineMaps.push_back(resultMap);
1774  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1775 }
1776 
1777 void ReduceOp::getEffects(
1778  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1779  &effects) {
1780  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1781 }
1782 
1783 Speculation::Speculatability ReduceOp::getSpeculatability() {
1784  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1785 }
1786 
1787 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1788  NamedAttrList &attributes,
1789  StringRef attributeName) {
1790  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1791  return failure();
1792 
1793  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1794  return success();
1795 }
1796 
1797 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1798  std::optional<OperationName> payloadOpName;
1799  NamedAttrList payloadOpAttrs;
1800  if (succeeded(parser.parseOptionalLBrace())) {
1801  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1802  if (failed(operationName))
1803  return failure();
1804  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1805  return failure();
1806  payloadOpName = operationName.value();
1807  if (parser.parseRBrace())
1808  return failure();
1809  }
1810 
1811  if (parseDstStyleOp(
1812  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1813  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1814  }))
1815  return failure();
1816 
1817  if (payloadOpName.has_value()) {
1818  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1819  ArrayRef(result.operands), /*initFirst=*/true);
1820  } else {
1821  SmallVector<OpAsmParser::Argument> regionArgs;
1822  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1823  /*allowType=*/true, /*allowAttrs=*/true)) {
1824  return failure();
1825  }
1826 
1827  Region *body = result.addRegion();
1828  if (parser.parseRegion(*body, regionArgs))
1829  return failure();
1830  }
1831 
1832  return success();
1833 }
1834 
1835 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1836  ArrayRef<int64_t> attributeValue) {
1837  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1838 }
1839 
1840 void ReduceOp::print(OpAsmPrinter &p) {
1841  Block *mapper = getBody();
1842  bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
1843  if (useShortForm) {
1844  printShortForm(p, &mapper->getOperations().front());
1845  }
1846 
1847  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1848  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1849  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1850  if (!useShortForm) {
1851  // Print region if the payload op was not detected.
1852  p.increaseIndent();
1853  p.printNewline();
1854  p << "(";
1855  llvm::interleaveComma(mapper->getArguments(), p,
1856  [&](auto arg) { p.printRegionArgument(arg); });
1857  p << ") ";
1858 
1859  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1860  p.decreaseIndent();
1861  }
1862 }
1863 
1864 LogicalResult ReduceOp::verify() {
1865  ArrayRef<int64_t> dimensionsRef = getDimensions();
1866 
1867  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1868  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1869  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1870  return emitOpError() << "expects all inputs to have the same shapes. "
1871  "Shape at input-index "
1872  << i
1873  << " is not equal to the shape at input-index 0.";
1874  }
1875  }
1876  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1877  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1878  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1879  return emitOpError() << "expects all outputs to have the same shapes. "
1880  "Shape at output-index "
1881  << i
1882  << " is not equal to the shape at output-index 0.";
1883  }
1884  }
1885  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1886  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1887 
1888  DenseSet<int64_t> dimensionsToReduce;
1889  for (int64_t dimension : dimensionsRef) {
1890  if (dimension < 0 || dimension >= inputType.getRank()) {
1891  return emitOpError()
1892  << "dimensions for reduction should be in the range [0, "
1893  << inputType.getRank() - 1 << "].";
1894  }
1895  dimensionsToReduce.insert(dimension);
1896  }
1897 
1898  auto inputDims = inputType.getShape();
1899  auto initDims = initType.getShape();
1900 
1901  // Input dimensions that will be left after the reduction.
1902  SmallVector<int64_t> reducedInputDims;
1903  for (const auto &en : llvm::enumerate(inputDims)) {
1904  if (!dimensionsToReduce.count(en.index()))
1905  reducedInputDims.push_back(en.value());
1906  }
1907 
1908  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1909  return emitOpError() << "number of dimensions after reduction "
1910  << reducedInputDims.size()
1911  << " doesn't match the init rank "
1912  << initType.getRank();
1913  }
1914 
1915  if (reducedInputDims != initDims)
1916  return emitOpError() << "init dimensions [" << initDims
1917  << "] doesn't match input dimensions after reduction ["
1918  << reducedInputDims << "]";
1919 
1920  Block *block = getBody();
1921  if (block->getNumArguments() != this->getNumOperands())
1922  return emitOpError()
1923  << "mismatching number of operands and block arguments";
1924 
1925  // Check that the first block arguments match the element type of the inputs.
1926  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1927  Type inputElementType =
1928  llvm::cast<ShapedType>(input.getType()).getElementType();
1929  if (inputElementType != bbArg.getType())
1930  return emitOpError()
1931  << "input element type " << inputElementType
1932  << " does not match corresponding block argument type "
1933  << bbArg.getType();
1934  }
1935 
1936  // Check that the last block arguments match the element type of the outputs.
1937  for (auto [output, bbArg] : llvm::zip(
1938  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1939  auto outputElementType =
1940  llvm::cast<ShapedType>(output.getType()).getElementType();
1941  if (outputElementType != bbArg.getType())
1942  return emitOpError()
1943  << "output element type " << outputElementType
1944  << " does not match corresponding block argument type "
1945  << bbArg.getType();
1946  }
1947  return success();
1948 }
1949 
1950 //===----------------------------------------------------------------------===//
1951 // TransposeOp
1952 //===----------------------------------------------------------------------===//
1953 
1954 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1955  Region &region, ValueRange inputs,
1956  ValueRange outputs) {
1957  buildGenericRegion(builder, loc, region, inputs, outputs,
1958  [](OpBuilder &b, Location loc, ValueRange args) {
1959  if (!args.empty())
1960  linalg::YieldOp::create(b, loc, args[0]);
1961  });
1962 }
1963 
1964 void TransposeOp::build(::mlir::OpBuilder &builder,
1965  ::mlir::OperationState &result, Value input, Value init,
1966  DenseI64ArrayAttr permutation,
1967  ArrayRef<NamedAttribute> attributes) {
1968  result.addOperands(input);
1969  result.addOperands(init);
1970  result.addAttribute(getPermutationAttrName(result.name), permutation);
1971  result.addAttributes(attributes);
1972 
1973  // Add output types for `RankedTensorType` output arguments.
1974  Type initType = init.getType();
1975  if (llvm::isa<RankedTensorType>(initType))
1976  result.addTypes(initType);
1977 
1978  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1979  init);
1980 }
1981 
1982 void TransposeOp::build(::mlir::OpBuilder &builder,
1983  ::mlir::OperationState &result, Value input, Value init,
1984  ArrayRef<int64_t> permutation,
1985  ArrayRef<NamedAttribute> attributes) {
1986  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1987  attributes);
1988 }
1989 
1990 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1991  if (failed(parseDstStyleOp(
1992  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1993  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1994  })))
1995  return failure();
1996 
1997  OpBuilder builder(parser.getContext());
1998  buildIdentityRegion(builder, result.location, *result.addRegion(),
1999  /*inputs=*/result.operands,
2000  /*outputs=*/{});
2001  return success();
2002 }
2003 
2004 void TransposeOp::getAsmResultNames(
2005  function_ref<void(Value, StringRef)> setNameFn) {
2006  if (!getResults().empty())
2007  setNameFn(getResults().front(), "transposed");
2008 }
2009 
2010 void TransposeOp::print(OpAsmPrinter &p) {
2011  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2012  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2013  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2014 }
2015 
2016 LogicalResult TransposeOp::verify() {
2017  ArrayRef<int64_t> permutationRef = getPermutation();
2018 
2019  if (!isPermutationVector(permutationRef))
2020  return emitOpError("permutation is not valid");
2021 
2022  auto inputType = getInput().getType();
2023  auto initType = getInit().getType();
2024 
2025  int64_t rank = inputType.getRank();
2026 
2027  if (rank != initType.getRank())
2028  return emitOpError() << "input rank " << rank
2029  << " does not match init rank " << initType.getRank();
2030 
2031  if (rank != static_cast<int64_t>(permutationRef.size()))
2032  return emitOpError() << "size of permutation " << permutationRef.size()
2033  << " does not match the argument rank " << rank;
2034 
2035  auto inputDims = inputType.getShape();
2036  auto initDims = initType.getShape();
2037 
2038  for (int64_t i = 0; i < rank; ++i) {
2039  int64_t inputDim = inputDims[permutationRef[i]];
2040  int64_t initDim = initDims[i];
2041 
2042  if (inputDim != initDim) {
2043  return emitOpError() << "dim(result, " << i << ") = " << initDim
2044  << " doesn't match dim(input, permutation[" << i
2045  << "]) = " << inputDim;
2046  }
2047  }
2048 
2049  return success();
2050 }
2051 
2052 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2053  int64_t rank = getInit().getType().getRank();
2054  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2055 }
2056 
2057 ArrayAttr TransposeOp::getIndexingMaps() {
2058  Builder builder(getContext());
2059  int64_t rank = getInit().getType().getRank();
2060  return builder.getAffineMapArrayAttr(
2061  {inversePermutation(AffineMap::getPermutationMap(
2062  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2063  builder.getMultiDimIdentityMap(rank)});
2064 }
2065 
2066 void TransposeOp::getEffects(
2067  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2068  &effects) {
2069  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2070 }
2071 
2072 Speculation::Speculatability TransposeOp::getSpeculatability() {
2073  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2074 }
2075 
2076 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2077  SmallVectorImpl<OpFoldResult> &result) {
2078  // Only the tensor type is supported.
2079  if (!isa<TensorType>(getInput().getType()))
2080  return failure();
2081 
2082  // Empty permutation (rank-0 transpose) folds to the input.
2083  if (getPermutation().size() == 0) {
2084  result.push_back(getInput());
2085  return success();
2086  }
2087  // Identity permutation.
2088  if (isIdentityPermutation(getPermutation())) {
2089  result.push_back(getInput());
2090  return success();
2091  }
2092 
2093  return failure();
2094 }
2095 
2096 /// Fold transpose with transpose.
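/// For example (a sketch; the shapes are illustrative), two transposes with
/// permutations [2, 0, 1] and [2, 0, 1]:
///
///   %t0 = linalg.transpose ins(%in : tensor<2x3x4xf32>)
///           outs(%e0 : tensor<4x2x3xf32>) permutation = [2, 0, 1]
///   %t1 = linalg.transpose ins(%t0 : tensor<4x2x3xf32>)
///           outs(%e1 : tensor<3x4x2xf32>) permutation = [2, 0, 1]
///
/// fold to a single transpose whose permutation is defPerms[perms[i]],
/// i.e. [1, 2, 0] here.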
2097 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2098  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2099 
2100  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2101  PatternRewriter &rewriter) const override {
2102  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2103  if (!defTransposeOp)
2104  return failure();
2105  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2106  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2107  SmallVector<int64_t> foldedPerms;
2108  foldedPerms.reserve(perms.size());
2109  for (int64_t perm : perms)
2110  foldedPerms.push_back(defPerms[perm]);
2111 
2112  rewriter.replaceOpWithNewOp<TransposeOp>(
2113  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2114  foldedPerms);
2115  return success();
2116  }
2117 };
2118 
2119 /// This pattern canonicalizes a transpose by swapping the order of
2120 /// broadcast and transpose:
2121 ///   transpose(broadcast(input)) -> broadcast(transpose(input))
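/// For example (a sketch; the shapes are illustrative):
///
///   %b = linalg.broadcast ins(%in : tensor<2x4xf32>)
///          outs(%e0 : tensor<2x3x4xf32>) dimensions = [1]
///   %t = linalg.transpose ins(%b : tensor<2x3x4xf32>)
///          outs(%e1 : tensor<4x2x3xf32>) permutation = [2, 0, 1]
///
/// becomes a transpose of the un-broadcast input followed by a broadcast:
///
///   %t2 = linalg.transpose ins(%in : tensor<2x4xf32>)
///           outs(%e2 : tensor<4x2xf32>) permutation = [1, 0]
///   %b2 = linalg.broadcast ins(%t2 : tensor<4x2xf32>)
///           outs(%e1 : tensor<4x2x3xf32>) dimensions = [2]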
2122 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2123  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2124 
2125  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2126  PatternRewriter &rewriter) const override {
2127  Value input = transposeOp.getInput();
2128  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2129  if (!input.hasOneUse() || !broadcastOp)
2130  return failure();
2131 
2132  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2133  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2134 
2135  // Get new perms and new dimensions.
2136  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2137  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2138  SmallVector<int64_t> resultDimensions;
2139  unsigned dimensionSize = dimensions.size();
2140  for (unsigned i = 0; i < dimensionSize; ++i)
2141  resultDimensions.push_back(invertPerm[dimensions[i]]);
2142 
2143  // Create transpose result.
2144  Value broadcastInput = broadcastOp.getInput();
2145  Location loc = transposeOp.getLoc();
2146  MLIRContext *ctx = transposeOp.getContext();
2147  SmallVector<OpFoldResult> dims;
2148  auto broadcastInputTy =
2149  mlir::cast<RankedTensorType>(broadcastInput.getType());
2150  unsigned inputRank = broadcastInputTy.getRank();
2151  for (unsigned i = 0; i < inputRank; ++i) {
2152  if (broadcastInputTy.isDynamicDim(i)) {
2153  dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2154  ->getResult(0));
2155  } else {
2156  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2157  broadcastInputTy.getDimSize(i)));
2158  }
2159  }
2160  SmallVector<OpFoldResult> transposeResultShapes =
2161  applyPermutation(dims, resultPerms);
2162  Value transposeInit = tensor::EmptyOp::create(
2163  rewriter, transposeOp.getLoc(), transposeResultShapes,
2164  broadcastInputTy.getElementType());
2165 
2166  // Create broadcast(transpose(input)).
2167  Value transposeResult =
2168  TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2169  transposeInit, resultPerms)
2170  ->getResult(0);
2171  rewriter.replaceOpWithNewOp<BroadcastOp>(
2172  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2173  return success();
2174  }
2175 };
2176 
2177 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2178  MLIRContext *context) {
2179  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2180 }
2181 
2182 //===----------------------------------------------------------------------===//
2183 // BroadcastOp
2184 //===----------------------------------------------------------------------===//
2185 
2186 void BroadcastOp::build(::mlir::OpBuilder &builder,
2187  ::mlir::OperationState &result, Value input, Value init,
2188  DenseI64ArrayAttr dimensions,
2189  ArrayRef<NamedAttribute> attributes) {
2190  result.addOperands(input);
2191  result.addOperands(init);
2192  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2193  result.addAttributes(attributes);
2194 
2195  // Add output types for `RankedTensorType` output arguments.
2196  Type initType = init.getType();
2197  if (llvm::isa<RankedTensorType>(initType))
2198  result.addTypes(initType);
2199 
2200  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2201  init);
2202 }
2203 
2204 void BroadcastOp::build(::mlir::OpBuilder &builder,
2205  ::mlir::OperationState &result, Value input, Value init,
2206  ArrayRef<int64_t> dimensions,
2207  ArrayRef<NamedAttribute> attributes) {
2208  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2209  attributes);
2210 }
2211 
2212 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2213  if (failed(parseDstStyleOp(
2214  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2215  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2216  })))
2217  return failure();
2218 
2219  OpBuilder builder(parser.getContext());
2220  buildIdentityRegion(builder, result.location, *result.addRegion(),
2221  /*inputs=*/result.operands,
2222  /*outputs=*/{});
2223  return success();
2224 }
2225 
2226 void BroadcastOp::getAsmResultNames(
2227  function_ref<void(Value, StringRef)> setNameFn) {
2228  if (!getResults().empty())
2229  setNameFn(getResults().front(), "broadcasted");
2230 }
2231 
2232 void BroadcastOp::print(OpAsmPrinter &p) {
2233  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2234  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2235  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2236 }
2237 
2238 LogicalResult BroadcastOp::verify() {
2239  ArrayRef<int64_t> dimensionsRef = getDimensions();
2240 
2241  auto inputType = getInput().getType();
2242  auto initType = getInit().getType();
2243 
2244  int64_t inputRank = inputType.getRank();
2245  int64_t initRank = initType.getRank();
2246 
2247  auto inputShape = inputType.getShape();
2248  auto initShape = initType.getShape();
2249 
2250  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2251  return emitOpError() << "input rank plus added dimensions does not "
2252  "match init rank. input rank: "
2253  << inputRank
2254  << ", dimensions size: " << dimensionsRef.size()
2255  << ", init rank: " << initRank;
2256 
2257  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2258  if (dim < 0 || dim >= initRank)
2259  return emitOpError() << "dimension " << idx
2260  << " is out of range. expected range: [0, "
2261  << initRank - 1 << "], got: " << dim;
2262  }
2263 
2264  // Mapping from input dims to init dims.
2265  SmallVector<int64_t> dimMap;
2266  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2267  if (!llvm::is_contained(dimensionsRef, dim))
2268  dimMap.push_back(dim);
2269  }
2270 
2271  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2272  // This dimension is mapped from the input, so the init and input dims
2273  // must match.
2274  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2275  return emitOpError() << "input dim " << inputDimIdx
2276  << " should match init dim " << initDimIdx
2277  << ". input: " << inputShape[inputDimIdx]
2278  << ", init: " << initShape[initDimIdx];
2279  }
2280 
2281  return success();
2282 }
2283 
2284 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2285  int64_t rank = getInit().getType().getRank();
2286  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2287 }
2288 
2289 ArrayAttr BroadcastOp::getIndexingMaps() {
2290  Builder builder(getContext());
2291  int64_t rank = getInit().getType().getRank();
2292  return builder.getAffineMapArrayAttr(
2293  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2294  builder.getMultiDimIdentityMap(rank)});
2295 }
2296 
2297 void BroadcastOp::getEffects(
2298  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2299  &effects) {
2300  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2301 }
2302 
2303 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2304  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2305 }
2306 
2307 /// Fold back-to-back broadcasts together.
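/// For example (a sketch; the shapes are illustrative):
///
///   %b0 = linalg.broadcast ins(%in : tensor<4xf32>)
///           outs(%e0 : tensor<4x5xf32>) dimensions = [1]
///   %b1 = linalg.broadcast ins(%b0 : tensor<4x5xf32>)
///           outs(%e1 : tensor<3x4x5xf32>) dimensions = [0]
///
/// folds to a single broadcast with dimensions = [0, 2].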
2308 struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
2309  using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
2310 
2311  LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
2312  PatternRewriter &rewriter) const override {
2313  auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
2314  if (!defBroadcastOp)
2315  return failure();
2316  ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
2317  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2318  SmallVector<int64_t> foldedDims(dimensions);
2319  Value init = broadcastOp.getInit();
2320  int64_t initRank = cast<ShapedType>(init.getType()).getRank();
2321  // Mapping from input dims to init dims.
2322  SmallVector<int64_t> dimMap;
2323  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2324  if (!llvm::is_contained(dimensions, dim))
2325  dimMap.push_back(dim);
2326  }
2327  for (auto dim : defDimensions)
2328  foldedDims.push_back(dimMap[dim]);
2329 
2330  llvm::sort(foldedDims);
2331  rewriter.replaceOpWithNewOp<BroadcastOp>(
2332  broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
2333  return success();
2334  }
2335 };
2336 
2337 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2338  MLIRContext *context) {
2339  results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
2340 }
2341 
2342 //===----------------------------------------------------------------------===//
2343 // YieldOp
2344 //===----------------------------------------------------------------------===//
2345 
2346 void linalg::YieldOp::print(OpAsmPrinter &p) {
2347  if (getNumOperands() > 0)
2348  p << ' ' << getOperands();
2349  p.printOptionalAttrDict((*this)->getAttrs());
2350  if (getNumOperands() > 0)
2351  p << " : " << getOperandTypes();
2352 }
2353 
2354 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2355  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2356  SmallVector<Type, 2> types;
2357  SMLoc loc = parser.getCurrentLocation();
2358  return failure(parser.parseOperandList(opInfo) ||
2359  parser.parseOptionalAttrDict(result.attributes) ||
2360  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2361  parser.resolveOperands(opInfo, types, loc, result.operands));
2362 }
2363 
2364 // Check that the number and types of the operands match the element types
2365 // of the LinalgOp interface's shaped operands.
2366 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2367  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2368  return op.emitOpError("expected number of yield values (")
2369  << op.getNumOperands()
2370  << ") to match the number of inits / outs operands of the enclosing "
2371  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2372 
2373  for (OpOperand &opOperand : op->getOpOperands()) {
2374  OpOperand *outputOperand =
2375  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2376  Type elementType = outputOperand->get().getType();
2377  if (isa<MemRefType, RankedTensorType>(elementType))
2378  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2379  if (opOperand.get().getType() != elementType)
2380  return op.emitOpError("type of yield operand ")
2381  << (opOperand.getOperandNumber() + 1) << " ("
2382  << opOperand.get().getType() << ") doesn't match "
2383  << "the element type of the enclosing linalg.generic op ("
2384  << elementType << ")";
2385  }
2386  return success();
2387 }
2388 
2389 LogicalResult linalg::YieldOp::verify() {
2390  auto *parentOp = (*this)->getParentOp();
2391  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2392  return emitOpError("expected single non-empty parent region");
2393 
2394  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2395  return verifyYield(*this, linalgOp);
2396 
2397  return emitOpError("expected parent op with LinalgOp interface");
2398 }
2399 
2400 //===----------------------------------------------------------------------===//
2401 // IndexOp
2402 //===----------------------------------------------------------------------===//
2403 
2404 LogicalResult IndexOp::verify() {
2405  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2406  if (!linalgOp)
2407  return emitOpError("expected parent op with LinalgOp interface");
2408  if (linalgOp.getNumLoops() <= getDim())
2409  return emitOpError("expected dim (")
2410  << getDim() << ") to be lower than the number of loops ("
2411  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2412  return success();
2413 }
2414 
2415 OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2416  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2417  // Bail out if `linalg.index` does not have a proper parent yet at this
2418  // point, e.g., when calling `createOrFold` during IR construction in
2419  // `GenericOp::build`.
2420  if (!linalgOp)
2421  return OpFoldResult{};
2422 
2423  // Index of unit dims is always 0.
2424  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2425  uint64_t dim = getDim();
2426  assert(dim < loopBounds.size() && "Dim is out of bounds");
2427  if (loopBounds[dim] == 1)
2428  return IntegerAttr::get(IndexType::get(getContext()), 0);
2429 
2430  return OpFoldResult{};
2431 }
2432 
2433 /////// Operations corresponding to library calls defined with Tablegen ////////
2434 
2435 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2436 
2437 #define GET_OP_CLASSES
2438 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2439 
2440 #define GET_OP_CLASSES
2441 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2442 #define GET_OP_CLASSES
2443 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2444 
2445 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2446  unsigned rank,
2447  MLIRContext *context) {
2448  if (maybeMap)
2449  return *maybeMap;
2450  if (rank == 0)
2451  return AffineMap::get(context);
2452  return AffineMap::getMultiDimIdentityMap(rank, context);
2453 }
2454 
2455 SmallVector<AffineExpr, 4>
2456 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2457  MLIRContext *context) {
2458  SmallVector<AffineExpr, 4> res;
2459  res.reserve(num);
2460  for (unsigned i = 0; i < num; ++i)
2461  res.push_back(getAffineDimExpr(startIdx++, context));
2462  return res;
2463 }
2464 
2465 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2466  ArrayRef<AffineExpr> b) {
2467  auto rangeA = llvm::make_range(a.begin(), a.end());
2468  auto rangeB = llvm::make_range(b.begin(), b.end());
2469  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2470  return llvm::to_vector<4>(concatRanges);
2471 }
2472 
2473 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2474  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2475  ss << "view";
2476  for (auto size : memref.getShape())
2477  if (size < 0)
2478  ss << "sx";
2479  else
2480  ss << size << "x";
2481  if (failed(appendMangledType(ss, memref.getElementType())))
2482  return failure();
2483  if (auto as = memref.getMemorySpace()) {
2484  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2485  ss << "as" << attr.getInt();
2486  else
2487  return failure();
2488  }
2489  return success();
2490  }
2491  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2492  ss << "vector";
2493  llvm::interleave(
2494  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2495  if (failed(appendMangledType(ss, vec.getElementType())))
2496  return failure();
2497  return success();
2498  }
2499  if (t.isSignlessIntOrIndexOrFloat()) {
2500  ss << t;
2501  return success();
2502  }
2503  return failure();
2504 }
2505 
2506 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2507  assert(isa<LinalgOp>(op));
2508  std::string name(op->getName().getStringRef().str());
2509  std::string fun = "";
2510  for (NamedAttribute kv : op->getAttrs()) {
2511  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2512  fun = stringifyEnum(ufa.getValue()).str() + "_";
2513  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2514  fun = stringifyEnum(bfa.getValue()).str() + "_";
2515  }
2516  }
2517  name.reserve(128);
2518  llvm::replace(name, '.', '_');
2519  llvm::raw_string_ostream ss(name);
2520  ss << "_" << fun;
2521  for (Type t : op->getOperandTypes()) {
2522  if (failed(appendMangledType(ss, t)))
2523  return std::string();
2524  ss << "_";
2525  }
2526  name.pop_back();
2527  return name;
2528 }
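// As an illustrative example of the mangling above, a linalg.matmul whose
// three operands are memref<?x?xf32> is expected to produce
// "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32": dynamic sizes mangle
// as "s" and each memref operand contributes one "view..." segment.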
2529 
2530 //===----------------------------------------------------------------------===//
2531 // Canonicalizers and Folders.
2532 //===----------------------------------------------------------------------===//
2533 
2534 namespace {
2535 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2536  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2537 
2538  LogicalResult matchAndRewrite(LinalgOp op,
2539  PatternRewriter &rewriter) const override {
2540  for (OpOperand &opOperand : op->getOpOperands()) {
2541  // Linalg "inputs" may be either tensor or memref type.
2542  // tensor<0xelt_type> is a convention that may not always mean
2543  // "0 iterations". Only erase in cases where we see memref<...x0x...>.
2544  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2545  if (!mt)
2546  continue;
2547  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2548  rewriter.eraseOp(op);
2549  return success();
2550  }
2551  }
2552  return failure();
2553  }
2554 };
2555 
2556 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2557 /// result that is more static than the linalg op.
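/// For example (a sketch; `linalg.exp` stands in for any LinalgOp):
///
///   %1 = linalg.exp ins(%0 : tensor<?x?xf32>)
///          outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<4x8xf32>
///
/// is rewritten so that the op computes directly on the more static type:
///
///   %init2 = tensor.cast %init : tensor<?x?xf32> to tensor<4x8xf32>
///   %2 = linalg.exp ins(%0 : tensor<?x?xf32>)
///          outs(%init2 : tensor<4x8xf32>) -> tensor<4x8xf32>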
2558 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2559  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2560 
2561  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2562  PatternRewriter &rewriter) const override {
2563  if (!tensor::canFoldIntoProducerOp(castOp))
2564  return failure();
2565 
2566  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2567  if (!linalgOp)
2568  return failure();
2569 
2570  // The cast can be in a conditionally reachable region, in which case
2571  // folding would generate invalid code. Only conservatively fold ops in
2572  // the same block for now.
2573  if (castOp->getBlock() != linalgOp->getBlock())
2574  return failure();
2575 
2576  OpBuilder::InsertionGuard guard(rewriter);
2577  rewriter.setInsertionPoint(linalgOp);
2578 
2579  Location loc = linalgOp.getLoc();
2580  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2581  unsigned resultNumber = resultValue.getResultNumber();
2582  auto resultType =
2583  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2584  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2585  // going from a more dynamic shape to a less dynamic shape. If the producer
2586  // for this cast, i.e. the producer of the out operand, is also an operation
2587  // that folds with a tensor.cast consumer (like this pattern), the cast will
2588  // continue to propagate as far up the stack as it can go.
2589  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2590  Value newOperand =
2591  tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2592  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2593  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2594  linalgOp.getDpsInits().end());
2595  outputOperands[resultNumber] = newOperand;
2596  newOperands.append(outputOperands.begin(), outputOperands.end());
2597 
2598  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2599  linalgOp->result_type_end());
2600  resultTypes[resultNumber] = resultType;
2601  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2602 
2603  // Create a tensor.cast operation back to the original type.
2604  Value castBack = tensor::CastOp::create(
2605  rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2606 
2607  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2608  results[resultNumber] = castBack;
2609  rewriter.replaceOp(linalgOp, results);
2610  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2611  return success();
2612  }
2613 };
2614 
2615 /// For each of the operands in `operands`, this function maps the static
2616 /// sizes of dimensions to their affine dim expressions.
2617 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2618  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2619  for (OpOperand &opOperand : operands) {
2620  if (linalgOp.isScalar(&opOperand))
2621  continue;
2622  Value src = opOperand.get();
2623  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2624  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2625 
2626  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2627  // `tensor.cast` operation and source of the cast operation has a static
2628  // shape, then assign it to the `sourceShape`.
2629  auto *parentOp = src.getDefiningOp();
2630  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2631  if (parentOp) {
2632  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2633  Value castSource = castOp.getSource();
2634  auto castSourceType =
2635  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2636  if (castSourceType && castSourceType.hasStaticShape())
2637  sourceShape = castSourceType.getShape();
2638  }
2639  }
2640 
2641  // If a dimension of the source shape is static, map the affine dim
2642  // expression to the known static size.
2643  for (unsigned i = 0; i < sourceShape.size(); i++) {
2644  if (sourceType.isDynamicDim(i))
2645  continue;
2646  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2647  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2648  }
2649  }
2650 }
2651 
2652 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with the static
2653 /// sizes mapped in `affineExprToSize`. New operands are appended to
2654 /// `newOperands` and their result types are stored in `resultTypes`. If
2655 /// `opOperand` requires no change, `changeNeeded` is left untouched and the
2656 /// same operand is added to the `newOperands` list.
2657 static void createNewOperandWithStaticSizes(
2658  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2659  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2660  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2661  bool &changeNeeded) {
2662  Value src = opOperand->get();
2663  newOperands.push_back(src);
2664  if (linalgOp.isScalar(opOperand))
2665  return;
2666  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2667  Type resultType = sourceType;
2668  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2669  resultTypes.push_back(resultType);
2670  return;
2671  }
2672  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2673  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2674  SmallVector<int64_t> newShape;
2675  // If the operand is updated with a new shape, `newOperandNeeded` will be
2676  // set to true.
2677  bool newOperandNeeded = false;
2678  for (unsigned i = 0; i < sourceShape.size(); i++) {
2679  int64_t dimShape = sourceShape[i];
2680  AffineExpr dimExpr = sourceMap.getResult(i);
2681  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2682  newShape.push_back(dimShape);
2683  continue;
2684  }
2685  // The dimension has a dynamic shape and the corresponding affine dim
2686  // expression is present in the map, so assign the size recorded for
2687  // that affine dim expression to the dimension.
2688  newShape.push_back(affineExprToSize[dimExpr]);
2689  newOperandNeeded = true;
2690  }
2691  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2692  sourceType.getEncoding());
2693  if (newOperandNeeded) {
2694  changeNeeded = true;
2695  // Get the new operand value given its size and element type by
2696  // casting it.
2697  Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2698  unsigned index = opOperand->getOperandNumber();
2699  newOperands[index] = newOperand;
2700  }
2701  if (linalgOp.isDpsInit(opOperand))
2702  resultTypes.push_back(resultType);
2703 }
2704 
2705 /// Static shapes for the operands can be inferred if any one of the operands
2706 /// has a static shape. This can be done by referring to the affine dim
2707 /// expressions for the operand.
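/// For example (a sketch), with an elementwise op where one operand is
/// fully static:
///
///   %r = linalg.add ins(%a, %b : tensor<4x8xf32>, tensor<?x?xf32>)
///          outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///
/// the shared loop dimensions imply %b and %init are also 4x8, so both are
/// `tensor.cast` to tensor<4x8xf32>, the op is cloned on the static types,
/// and a `tensor.cast` back to tensor<?x?xf32> restores the original
/// result type.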
2708 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2709  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2710 
2711  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2712  PatternRewriter &rewriter) const override {
2713  if (!linalgOp.hasPureTensorSemantics())
2714  return failure();
2715 
2716  // Maps must be projected permutations.
2717  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2718  return !map.isProjectedPermutation();
2719  }))
2720  return failure();
2721 
2722  // Maps affine dim expressions to the static size of that dimension.
2723  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2724  Location loc = linalgOp.getLoc();
2725 
2726  // For each affine dim expression, check if the size is known. If it is
2727  // known, add it to the map.
2728  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2729 
2730  SmallVector<Value> newOperands;
2731  SmallVector<Type> resultTypes;
2732 
2733  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2734  // change in their types.
2735  bool changeNeeded = false;
2736  newOperands.reserve(linalgOp->getNumOperands());
2737  resultTypes.reserve(linalgOp.getNumDpsInits());
2738 
2739  // Iterate over all the operands and update the static sizes.
2740  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2741  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2742  affineExprToSize, linalgOp, newOperands,
2743  resultTypes, changeNeeded);
2744  }
2745 
2746  // If the generic op has all the required static information, no
2747  // canonicalization needed.
2748  if (!changeNeeded)
2749  return failure();
2750 
2751  // Clone op.
2752  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2753  SmallVector<Value> replacements;
2754  replacements.reserve(newOp->getNumResults());
2755  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2756  Value newResult = std::get<1>(it);
2757  Value oldResult = std::get<0>(it);
2758  Type newType = newResult.getType();
2759  Type oldType = oldResult.getType();
2760  replacements.push_back(
2761  (newType != oldType)
2762  ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2763  : newResult);
2764  }
2765  rewriter.replaceOp(linalgOp, replacements);
2766  return success();
2767  }
2768 };
2769 
2770 } // namespace
2771 
2772 // All named ops' canonicalizers and folders are auto-generated in the
2773 // .cpp.inc.
2774 
2775 //===----------------------------------------------------------------------===//
2776 // SoftmaxOp
2777 //===----------------------------------------------------------------------===//
2778 
2779 LogicalResult SoftmaxOp::verify() {
2780  ShapedType inputType = getInputOperandType();
2781  ShapedType outputType = getOutputOperandType();
2782 
2783  ArrayRef<int64_t> inputShape = inputType.getShape();
2784  ArrayRef<int64_t> outputShape = outputType.getShape();
2785  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2786  return emitOpError("incompatible output shape");
2787 
2788  int64_t inputRank = getInputOperandRank();
2789  int64_t dimension = getDimension();
2790  if ((dimension < 0) || (dimension >= inputRank))
2791  return emitOpError("incorrect dimension specified");
2792 
2793  return success();
2794 }
2795 
2796 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2797  int64_t operandRank = getInputOperandRank();
2798  SmallVector<Range> loopBounds(operandRank);
2799  Location loc = getLoc();
2800  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2801  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2802  Value source = getInput();
2803  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2804  loopBounds[dim].offset = zero;
2805  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2806  loopBounds[dim].stride = one;
2807  }
2808  return loopBounds;
2809 }
2810 
2811 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2812  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2813  utils::IteratorType::parallel);
2814  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2815  return iteratorTypes;
2816 }
2817 
2818 FailureOr<TilingResult>
2819 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2820  ArrayRef<OpFoldResult> offsets,
2821  ArrayRef<OpFoldResult> sizes) {
2822  int64_t rank = getInputOperandRank();
2823  auto oneAttr = builder.getI64IntegerAttr(1);
2824  SmallVector<OpFoldResult> strides(rank, oneAttr);
2825  SmallVector<Value> tiledOperands;
2826  Operation *inputSlice =
2827  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2828  if (!inputSlice) {
2829  return emitOpError("failed to compute input slice");
2830  }
2831  tiledOperands.emplace_back(inputSlice->getResult(0));
2832  Operation *outputSlice =
2833  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2834  if (!outputSlice) {
2835  return emitOpError("failed to compute output slice");
2836  }
2837  tiledOperands.emplace_back(outputSlice->getResult(0));
2838 
2839  SmallVector<Type, 4> resultTypes;
2840  if (hasPureTensorSemantics())
2841  resultTypes.push_back(tiledOperands[1].getType());
2842  Operation *tiledOp =
2843  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2844 
2845  return TilingResult{
2846  {tiledOp},
2847  SmallVector<Value>(tiledOp->getResults()),
2848  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2849 }
2850 
2851 LogicalResult SoftmaxOp::getResultTilePosition(
2852  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2853  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2854  SmallVector<OpFoldResult> &resultSizes) {
2855  if (resultNumber == 0) {
2856  resultOffsets.assign(offsets.begin(), offsets.end());
2857  resultSizes.assign(sizes.begin(), sizes.end());
2858  return success();
2859  }
2860  return failure();
2861 }
2862 
2863 // cast(dynamic) -> static.
2864 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2865  return memref::foldMemRefCast(*this);
2866 }
2867 
2868 LogicalResult
2869 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2870  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2871  SmallVector<OpFoldResult> shapes;
2872  Location loc = getOperation()->getLoc();
2873  IRRewriter rewriter(b);
2874  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2875  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2876  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2877  if (!outputShapedType.isDynamicDim(dim)) {
2878  // Static dim: Return IntegerAttr.
2879  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2880  } else {
2881  // Dynamic dim: Return Value.
2882  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2883  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2884  }
2885  }
2886  reifiedReturnShapes.emplace_back(std::move(shapes));
2887  return success();
2888 }
2889 
2890 void SoftmaxOp::getEffects(
2891  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2892  &effects) {
2893  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2894  if (!llvm::isa<MemRefType>(operand.getType()))
2895  continue;
2896  effects.emplace_back(MemoryEffects::Read::get(),
2897  &getOperation()->getOpOperand(index), /*stage=*/0,
2898  /*effectOnFullRegion=*/true,
2899  SideEffects::DefaultResource::get());
2900  }
2901 
2902  for (OpOperand &operand : getDpsInitsMutable()) {
2903  if (!llvm::isa<MemRefType>(operand.get().getType()))
2904  continue;
2905  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2906  /*effectOnFullRegion=*/true,
2907  SideEffects::DefaultResource::get());
2908  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2909  /*effectOnFullRegion=*/true,
2910  SideEffects::DefaultResource::get());
2911  }
2912 }
2913 
2914 // Helper functions for softmax decomposition.
2915 // @{
2916 
2917 // Helper function to produce the iterator types (reduction or parallel) and
2918 // affine maps for the iterators used in the decomposition of softmax.
2919 // This method creates:
2920 // If allParallel == true:
2921 // - iterator type: {parallel, ..., parallel}
2922 // - affine maps:
2923 // -- identity with inputRank dimensions.
2924 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2925 // where N == inputRank - 1.
2926 //
2927 // If allParallel == false:
2928 // - iterator type at dim(i) == parallel for i != \p dim and
2929 // dim(dim) == reduction.
2930 // - affine maps:
2931 // -- identity with inputRank dimensions.
2932 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2933 // where N == inputRank - 1.
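// For instance (an illustrative example), with inputRank == 3 and dim == 1:
// - iterator types: {parallel, reduction, parallel}, or all parallel when
//   allParallel == true.
// - affine maps: (d0, d1, d2) -> (d0, d1, d2) and (d0, d1, d2) -> (d0, d2).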
2934 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2935 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2936  int64_t dim, bool allParallel = false) {
2937  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2938  utils::IteratorType::parallel);
2939  if (!allParallel)
2940  iteratorTypes[dim] = utils::IteratorType::reduction;
2941  MLIRContext *ctxt = builder.getContext();
2942  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2943  SmallVector<AffineExpr, 2> affineExprs;
2944  for (int i = 0; i < inputRank; i++) {
2945  if (i != dim)
2946  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2947  }
2948  auto reductionMap =
2949  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2950  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2951  return std::make_tuple(iteratorTypes, indexingMaps);
2952 }
2953 
2954 // Helper function to produce a linalg.generic that computes a reduction on
2955 // dimension \p dim with the operation type \p T.
2956 template <typename T>
2957 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2958  int64_t dim) {
2959  auto inputType = cast<ShapedType>(input.getType());
2960  ArrayRef<int64_t> inputShape = inputType.getShape();
2961  int64_t inputRank = inputShape.size();
2962  auto [iteratorTypes, indexingMaps] =
2963  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2964  assert(indexingMaps.size() == 2 &&
2965  "We should have two maps: 1 for the input, 1 for the output");
2966  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2967 
2968  auto genericOp = linalg::GenericOp::create(
2969  builder, loc, output.getType(), input, output, indexingMaps,
2970  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2971  Value result = T::create(b, loc, args[0], args[1]);
2972  linalg::YieldOp::create(b, loc, result);
2973  });
2974  return genericOp.getResult(0);
2975 }
2976 
2977 /// Produce a linalg generic that computes the second step of the softmax
2978 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2979 /// on dimension \p dim.
2980 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2981  Value max, Value output, int64_t dim) {
2982  auto inputType = cast<ShapedType>(input.getType());
2983  ArrayRef<int64_t> inputShape = inputType.getShape();
2984  int64_t inputRank = inputShape.size();
2985  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2986  builder, inputRank, dim, /*allParallel=*/true);
2987  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2988  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2989  // Add the affine map for the output argument.
2990  indexingMaps.push_back(indexingMaps[0]);
2991  auto genericOp = linalg::GenericOp::create(
2992  builder, loc, input.getType(), ValueRange{input, max}, output,
2993  indexingMaps, iteratorTypes,
2994  [&](OpBuilder &b, Location loc, ValueRange args) {
2995  Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
2996  Value result = math::ExpOp::create(b, loc, diff);
2997  linalg::YieldOp::create(b, loc, result);
2998  });
2999  return genericOp.getResult(0);
3000 }
3001 
3002 /// Produce a linalg generic that computes the final step of the softmax
3003 /// decomposition.
3004 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
3005 /// yield n / d
3006 /// }
3007 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
3008  Value denominator, Value output, int64_t dim) {
3009  auto inputType = cast<ShapedType>(numerator.getType());
3010  ArrayRef<int64_t> inputShape = inputType.getShape();
3011  int64_t inputRank = inputShape.size();
3012  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
3013  builder, inputRank, dim, /*allParallel=*/true);
3014  assert(indexingMaps.size() == 2 &&
3015  "We should have one map for each input (2)");
3016  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
3017  // Add the affine map for the output tensor.
3018  indexingMaps.push_back(indexingMaps[0]);
3019  auto genericOp = linalg::GenericOp::create(
3020  builder, loc, numerator.getType(), ValueRange{numerator, denominator},
3021  output, indexingMaps, iteratorTypes,
3022  [&](OpBuilder &b, Location loc, ValueRange args) {
3023  Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
3024  linalg::YieldOp::create(b, loc, result);
3025  });
3026  return genericOp.getResult(0);
3027 }
3028 // @} End helper functions for softmax decomposition.
3029 
3030 /// Given an N-dimensional tensor x, this method converts
3031 /// softmax(x) to the following sequence of operations:
3032 ///
3033 /// 1. Compute the max of x along dimension d. This results
3034 /// in a N-1 dimensional tensor m.
3035 /// m = max(x, dim = d)
3036 ///
3037 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
3038 /// a N dimensional tensor z.
3039 /// z = exp(x - m)
3040 ///
3041 /// 3. Compute the sum of z along dimension d. This results in
3042 /// a N-1 dimensional tensor l.
3043 /// l = sum(z, dim = d)
3044 ///
3045 /// 4. Divide z and l. This gives the N-dimensional softmax.
3046 /// softmax = z / l
3047 ///
3048 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3049  OpBuilder::InsertionGuard guard(b);
3050  b.setInsertionPoint(*this);
3051  Location loc = getLoc();
3052  Value input = getInput();
3053  ShapedType inputType = getInputOperandType();
3054  Type elementType = inputType.getElementType();
3055  int64_t reductionDim = getDimension();
3056  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3057  Value output = getOutput();
3058  dims.erase(dims.begin() + reductionDim);
3059  // Step 1: Compute max along dim.
3060  Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3061  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3062  elementType, b, loc,
3063  /*useOnlyFiniteValue=*/true);
3064  Value neutralForMaxFInit =
3065  linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3066  .result();
3067  Value max =
3068  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3069 
3070  // Step 2: Subtract max from input and exponentiate.
3071  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3072 
3073  // Step 3: Compute sum along dim.
3074  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3075  b, loc, /*useOnlyFiniteValue=*/true);
3076  Value zeroInit =
3077  linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3078  Value denominator =
3079  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3080 
3081  // Step 4: Compute softmax.
3082  Value result =
3083  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3084  return SmallVector<Value>{result};
3085 }
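// A minimal usage sketch (the names `rewriter` and `softmaxOp` are assumed to
// be provided by the surrounding pattern or pass):
//
//   FailureOr<SmallVector<Value>> decomposed =
//       softmaxOp.decomposeOperation(rewriter);
//   if (failed(decomposed))
//     return failure();
//   rewriter.replaceOp(softmaxOp, *decomposed);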
3086 
3087 //===----------------------------------------------------------------------===//
3088 // WinogradFilterTransformOp
3089 //===----------------------------------------------------------------------===//
3090 
3091 LogicalResult WinogradFilterTransformOp::verify() {
3092  auto filterType = cast<ShapedType>(getFilter().getType());
3093  ArrayRef<int64_t> filterShape = filterType.getShape();
3094  int64_t filterH = filterShape[getFilterHDim()];
3095  int64_t filterW = filterShape[getFilterWDim()];
3096  WinogradConv2DFmr fmr = getFmr();
3097  int64_t m, r;
3098  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3099 
3100  if (filterH != r && filterH != 1)
3101  return emitOpError("expected filter height to be either r or 1");
3102  if (filterW != r && filterW != 1)
3103  return emitOpError("expected filter width to be either r or 1");
3104  if (filterH == 1 && filterW == 1)
3105  return emitOpError("expected either filter height or filter width to equal r");
3106 
3107  SmallVector<int64_t> expectedOutputShape;
3108  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3109  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3110  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3111  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3112 
3113  auto outputType = cast<ShapedType>(getOutput().getType());
3114  ArrayRef<int64_t> outputShape = outputType.getShape();
3115  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3116  return emitOpError("unexpected output shape");
3117  }
3118  return success();
3119 }
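// For example, with fmr == F_4_3 (m == 4, r == 3) and an (F, KH, KW, C)
// filter of shape (f, 3, 3, c), both transforms apply and
// alpha == m + r - 1 == 6, so the expected output shape is (6, 6, c, f).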
3120 
3121 SmallVector<Range>
3122 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3123  Location loc = getLoc();
3124  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3125  IntegerAttr oneAttr = builder.getIndexAttr(1);
3126  Value filter = getFilter();
3127  int64_t filterRank = getFilterOperandRank();
3128  SmallVector<Range> loopBounds(filterRank);
3129  for (unsigned dim = 0; dim < filterRank; ++dim) {
3130  loopBounds[dim].offset = zeroAttr;
3131  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3132  loopBounds[dim].stride = oneAttr;
3133  }
3134  return loopBounds;
3135 }
3136 
3137 SmallVector<utils::IteratorType>
3138 WinogradFilterTransformOp::getLoopIteratorTypes() {
3139  int64_t filterRank = getFilterOperandRank();
3140  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3141  utils::IteratorType::parallel);
3142  return iteratorTypes;
3143 }
3144 
3145 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3146  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3147  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3148  SmallVector<OpFoldResult> &resultSizes) {
3149  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3150  ShapedType filterType = getFilterOperandType();
3151  ArrayRef<int64_t> filterShape = filterType.getShape();
3152  int64_t filterH = filterShape[getFilterHDim()];
3153  int64_t filterW = filterShape[getFilterWDim()];
3154  WinogradConv2DFmr fmr = getFmr();
3155  int64_t m, r;
3156  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3157  int64_t alpha = m + r - 1;
3158  int64_t alphaH = filterH != 1 ? alpha : 1;
3159  int64_t alphaW = filterW != 1 ? alpha : 1;
3160  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3161  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3162 
3163  resultOffsets.append(
3164  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3165  resultSizes.append(
3166  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3167 
3168  return success();
3169 }
3170 
3171 /// Implement tiling for winograd_filter_transform.
3172 /// The input of winograd_filter_transform is (F, KH, KW, C).
3173 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3174 /// Users can specify the tile sizes of F and C.
3175 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3176 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
3177 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3178  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3179  ArrayRef<OpFoldResult> sizes) {
3180  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3181  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3182  ShapedType filterType = getFilterOperandType();
3183  ArrayRef<int64_t> filterShape = filterType.getShape();
3184  int64_t filterH = filterShape[getFilterHDim()];
3185  int64_t filterW = filterShape[getFilterWDim()];
3186  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3187  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3188  SmallVector<Value> tiledOperands;
3189  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3190 
3191  sliceOffsets.append(
3192  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3193  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3194  sizes[getFilterCDim()]});
3195  int64_t filterRank = getFilterOperandRank();
3196  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3197  Location loc = getLoc();
3198  auto filterSlice = tensor::ExtractSliceOp::create(
3199  builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3200  tiledOperands.emplace_back(filterSlice);
3201 
3202  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3203  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3204  resultSizes)))
3205  return failure();
3206 
3207  int64_t outputRank = getOutputOperandRank();
3208  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3209  auto outputSlice = tensor::ExtractSliceOp::create(
3210  builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3211  tiledOperands.emplace_back(outputSlice);
3212 
3213  SmallVector<Type> resultTypes;
3214  resultTypes.push_back(tiledOperands[1].getType());
3215  Operation *tiledOp =
3216  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3217 
3218  return TilingResult{
3219  {tiledOp},
3220  SmallVector<Value>(tiledOp->getResults()),
3221  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3222 }
3223 
3224 //===----------------------------------------------------------------------===//
3225 // WinogradInputTransformOp
3226 //===----------------------------------------------------------------------===//
3227 
3228 LogicalResult WinogradInputTransformOp::verify() {
3229  auto inputType = cast<ShapedType>(getInput().getType());
3230  ArrayRef<int64_t> inputShape = inputType.getShape();
3231  int64_t inputH = inputShape[getInputHDim()];
3232  int64_t inputW = inputShape[getInputWDim()];
3233  WinogradConv2DFmr fmr = getFmr();
3234  int64_t m, r;
3235  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3236  int64_t tileSize = m + r - 1;
3237 
3238  auto outputType = cast<ShapedType>(getOutput().getType());
3239  ArrayRef<int64_t> outputShape = outputType.getShape();
3240  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3241  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3242 
3243  SmallVector<int64_t> expectedOutputShape(6, inputH);
3244  if (ShapedType::isDynamic(inputH)) {
3245  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3246  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3247  } else {
3248  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3249  expectedOutputShape[getOutputTileHDim()] =
3250  leftTransform ? (inputH - (r - 1)) / m : inputH;
3251  }
3252  if (ShapedType::isDynamic(inputW)) {
3253  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3254  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3255  } else {
3256  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3257  expectedOutputShape[getOutputTileWDim()] =
3258  rightTransform ? (inputW - (r - 1)) / m : inputW;
3259  }
3260  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3261  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3262 
3263  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3264  return emitOpError("unexpected output shape");
3265  }
3266  return success();
3267 }
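// For example, with fmr == F_4_3 (m == 4, r == 3) and an (N, H, W, C) input
// of shape (n, 10, 10, c), tileSize == 6 and each spatial dimension yields
// (10 - (3 - 1)) / 4 == 2 tiles, so the expected output shape is
// (6, 6, 2, 2, n, c).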
3268 
3269 SmallVector<Range>
3270 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3271  Location loc = getLoc();
3272  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3273  IntegerAttr oneAttr = builder.getIndexAttr(1);
3274  Value output = getOutput();
3275  int64_t outputRank = getOutputOperandRank();
3276  SmallVector<Range> loopBounds(outputRank);
3277  for (unsigned dim = 0; dim < outputRank; ++dim) {
3278  loopBounds[dim].offset = zeroAttr;
3279  // alphaH, alphaW, tileH, tileW, N, C
3280  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3281  loopBounds[dim].stride = oneAttr;
3282  }
3283  return loopBounds;
3284 }
3285 
3286 SmallVector<utils::IteratorType>
3287 WinogradInputTransformOp::getLoopIteratorTypes() {
3288  int64_t outputRank = getOutputOperandRank();
3289  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3290  utils::IteratorType::parallel);
3291  return iteratorTypes;
3292 }
3293 
3294 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3295  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3296  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3297  SmallVector<OpFoldResult> &resultSizes) {
3298  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3299  ShapedType outputType = getOutputOperandType();
3300  ArrayRef<int64_t> outputShape = outputType.getShape();
3301  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3302  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3303 
3304  WinogradConv2DFmr fmr = getFmr();
3305  int64_t m, r;
3306  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3307  int64_t alpha = m + r - 1;
3308  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3309  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3310 
3311  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3312  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3313 
3314  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3315  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3316  offsets[getOutputCDim()]});
3317  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3318  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3319  sizes[getOutputCDim()]});
3320 
3321  return success();
3322 }
3323 
3324 /// Implement tiling for winograd_input_transform.
3325 /// The input of winograd_input_transform is (N, H, W, C).
3326 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3327 /// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3328 /// `offsets` are the values for the offsets of tileH, tileW, N, C for one
3329 /// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one tile.
3330 FailureOr<TilingResult>
3331 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3332  ArrayRef<OpFoldResult> offsets,
3333  ArrayRef<OpFoldResult> sizes) {
3334  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3335  WinogradConv2DFmr fmr = getFmr();
3336  int64_t m, r;
3337  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3338 
3339  ShapedType outputType = getOutputOperandType();
3340  ArrayRef<int64_t> outputShape = outputType.getShape();
3341  int64_t alphaH = outputShape[getOutputAlphaHDim()];
3342  int64_t alphaW = outputShape[getOutputAlphaWDim()];
3343 
3344  Location loc = getLoc();
3345  MLIRContext *context = builder.getContext();
3346  auto identityAffineMap =
3347  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3348  auto offsetAffineMap =
3349  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3350  Value mappedOffsetH = affine::makeComposedAffineApply(
3351  builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3352  offsets[getOutputTileHDim()]);
3353  Value mappedOffsetW = affine::makeComposedAffineApply(
3354  builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3355  offsets[getOutputTileWDim()]);
3356  auto sizeAffineMap = AffineMap::get(
3357  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3358  Value mappedSizeH = affine::makeComposedAffineApply(
3359  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3360  Value mappedSizeW = affine::makeComposedAffineApply(
3361  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3362 
3363  SmallVector<Value> tiledOperands;
3364  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3365 
3366  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3367  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3368  sliceOffsets.append(
3369  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3370  OpFoldResult sizeH =
3371  alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3372  OpFoldResult sizeW =
3373  alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3374  sliceSizes.append(
3375  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3376  int64_t inputRank = getInputOperandRank();
3377  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3378  auto inputSlice = tensor::ExtractSliceOp::create(
3379  builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3380  tiledOperands.emplace_back(inputSlice);
3381 
3382  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3383  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3384  resultSizes)))
3385  return failure();
3386 
3387  int64_t outputRank = getOutputOperandRank();
3388  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3389  auto outputSlice = tensor::ExtractSliceOp::create(
3390  builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3391  tiledOperands.emplace_back(outputSlice);
3392 
3393  SmallVector<Type> resultTypes;
3394  resultTypes.push_back(tiledOperands[1].getType());
3395  Operation *tiledOp =
3396  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3397 
3398  return TilingResult{
3399  {tiledOp},
3400  SmallVector<Value>(tiledOp->getResults()),
3401  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3402 }
3403 
3404 //===----------------------------------------------------------------------===//
3405 // WinogradOutputTransformOp
3406 //===----------------------------------------------------------------------===//
3407 
3408 LogicalResult WinogradOutputTransformOp::verify() {
3409  auto valueType = cast<ShapedType>(getValue().getType());
3410  ArrayRef<int64_t> valueShape = valueType.getShape();
3411  int64_t valueH = valueShape[getValueAlphaHDim()];
3412  int64_t valueW = valueShape[getValueAlphaWDim()];
3413  int64_t valueTileH = valueShape[getValueTileHDim()];
3414  int64_t valueTileW = valueShape[getValueTileWDim()];
3415  WinogradConv2DFmr fmr = getFmr();
3416  int64_t m, r;
3417  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3418  bool leftTransform = valueH != 1;
3419  bool rightTransform = valueW != 1;
3420 
3421  int64_t outputRank = getOutputOperandRank();
3422  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3423  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3424  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3425  } else {
3426  if (valueH != (leftTransform ? m + r - 1 : 1))
3427  return emitOpError("expected input height to equal the input tile size");
3428  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3429  }
3430  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3431  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3432  } else {
3433  if (valueW != (rightTransform ? m + r - 1 : 1))
3434  return emitOpError("expected input width to equal the input tile size");
3435  expectedOutputShape[getOutputWDim()] =
3436  (rightTransform ? m : 1) * valueTileW;
3437  }
3438  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3439  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3440 
3441  auto outputType = cast<ShapedType>(getOutput().getType());
3442  ArrayRef<int64_t> outputShape = outputType.getShape();
3443  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3444  return emitOpError("unexpected output shape");
3445  }
3446  return success();
3447 }
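// For example, with fmr == F_4_3 (m == 4, r == 3) and a value of shape
// (6, 6, 2, 2, n, f), both transforms apply and each spatial dimension
// reconstructs m * tile == 4 * 2 == 8 pixels, so the expected output shape is
// (n, 8, 8, f).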
3448 
3449 SmallVector<Range>
3450 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3451  Location loc = getLoc();
3452  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3453  IntegerAttr oneAttr = builder.getIndexAttr(1);
3454  Value value = getValue();
3455  int64_t valueRank = getValueOperandRank();
3456  SmallVector<Range> loopBounds(valueRank);
3457  for (unsigned dim = 0; dim < valueRank; ++dim) {
3458  loopBounds[dim].offset = zeroAttr;
3459  // alphaH, alphaW, tileH, tileW, N, F
3460  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3461  loopBounds[dim].stride = oneAttr;
3462  }
3463  return loopBounds;
3464 }
3465 
3466 SmallVector<utils::IteratorType>
3467 WinogradOutputTransformOp::getLoopIteratorTypes() {
3468  int64_t valueRank = getValueOperandRank();
3469  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3470  utils::IteratorType::parallel);
3471  return iteratorTypes;
3472 }
3473 
3474 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3475  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3476  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3477  SmallVector<OpFoldResult> &resultSizes) {
3478  WinogradConv2DFmr fmr = getFmr();
3479  int64_t m, r;
3480  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3481 
3482  Location loc = getLoc();
3483  MLIRContext *context = builder.getContext();
3484  auto identityAffineMap =
3485  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3486  auto affineMap =
3487  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3488 
3489  ShapedType valueType = getValueOperandType();
3490  ArrayRef<int64_t> valueShape = valueType.getShape();
3491  int64_t valueH = valueShape[0];
3492  int64_t valueW = valueShape[1];
3493  Value mappedOffsetH = affine::makeComposedAffineApply(
3494  builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3495  offsets[getValueTileHDim()]);
3496  Value mappedOffsetW = affine::makeComposedAffineApply(
3497  builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3498  offsets[getValueTileWDim()]);
3499  Value mappedSizeH = affine::makeComposedAffineApply(
3500  builder, loc, affineMap, sizes[getValueTileHDim()]);
3501  Value mappedSizeW = affine::makeComposedAffineApply(
3502  builder, loc, affineMap, sizes[getValueTileWDim()]);
3503 
3504  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3505  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3506  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3507  OpFoldResult sizeH =
3508  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3509  OpFoldResult sizeW =
3510  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3511 
3512  resultOffsets.append(
3513  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3514  resultSizes.append(
3515  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3516  return success();
3517 }
3518 
3519 /// Implement tiling for winograd_output_transform.
3520 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3521 /// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
3522 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3523 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3524 /// for the sizes of tileH, tileW, N, F for one tile.
3525 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3526  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3527  ArrayRef<OpFoldResult> sizes) {
3528  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3529  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3530  Location loc = getLoc();
3531  SmallVector<Value> tiledOperands;
3532  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3533 
3534  ShapedType valueType = getValueOperandType();
3535  ArrayRef<int64_t> valueShape = valueType.getShape();
3536  int64_t alphaH = valueShape[getValueAlphaHDim()];
3537  int64_t alphaW = valueShape[getValueAlphaWDim()];
3538  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3539  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3540 
3541  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3542  offsets[getValueTileWDim()], offsets[getValueNDim()],
3543  offsets[getValueFDim()]});
3544  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3545  sizes[getValueTileWDim()], sizes[getValueNDim()],
3546  sizes[getValueFDim()]});
3547  int64_t valueRank = getValueOperandRank();
3548  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3549  auto valueSlice = tensor::ExtractSliceOp::create(
3550  builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3551  tiledOperands.emplace_back(valueSlice);
3552 
3553  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3554  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3555  resultSizes)))
3556  return failure();
3557 
3558  int64_t outputRank = getOutputOperandRank();
3559  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3560  auto outputSlice = tensor::ExtractSliceOp::create(
3561  builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3562  tiledOperands.emplace_back(outputSlice);
3563 
3564  SmallVector<Type> resultTypes;
3565  resultTypes.push_back(tiledOperands[1].getType());
3566  Operation *tiledOp =
3567  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3568 
3569  return TilingResult{
3570  {tiledOp},
3571  SmallVector<Value>(tiledOp->getResults()),
3572  llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3573 }
3574 
3575 //===----------------------------------------------------------------------===//
3576 // LinalgDialect
3577 // TODO: Merge with the LinalgDialect block at the bottom
3578 //===----------------------------------------------------------------------===//
3579 
3580 // Returns true if the result expressions of `subMap` are a subset of those of `fullMap`.
3581 static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3582  auto explicitRange = subMap.getResults();
3583  auto defaultRange = fullMap.getResults();
3584  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3585  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3586  llvm::set_union(explicitSet, defaultSet);
3587  return explicitSet == defaultSet;
3588 }
3589 
3590 /// Check if the user-defined map is a valid broadcast map. Here broadcast
3591 /// indexing maps are defined in the context of the corresponding default
3592 /// indexing maps for the given Op, so the check reduces to comparing the
3593 /// number of result dims.
3594 /// Returns true if the explicitMap is broadcast with respect to the
3595 /// defaultMap.
3596 static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3597  return explicitMap.getNumResults() < defaultMap.getNumResults();
3598 }
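// For example, against the default matmul LHS map (d0, d1, d2) -> (d0, d2), a
// user-provided map (d0, d1, d2) -> (d2) has fewer results and is therefore
// treated as a broadcast of the LHS along the non-K dimension.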
3599 
3600 /// Verifies the broadcast and transpose semantics specified by the explicit
3601 /// indexing map for the MatmulOp \p matmulOp for the operand specified by \p
3602 /// opIndex.
3603 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3604  unsigned opIndex) {
3605  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3606  SmallVector<AffineMap, 3> defaultIndexingMaps =
3607  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3608 
3609  auto opIndexingMap = opIndexingMaps[opIndex];
3610  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3611  // Check general validity of indexing map results.
3612  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3613  return matmulOp->emitOpError()
3614  << "Unexpected dim expression in map result.";
3615 
3616  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3617  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3618  return matmulOp->emitOpError()
3619  << "Invalid broadcast requested, should be (d2).";
3620  }
3621  return success();
3622  }
3623  return success();
3624 }
3625 
3626 // Check general validity of input indexing map of
3627 // BatchMatmulOp/BatchReduceMatmulOp.
3628 template <typename OpTy>
3629 static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3630  AffineMap opIndexingMap,
3631  AffineMap defaultIndexingMap, bool isLHS) {
3632  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3633  isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3634  "Expected BatchMatmulOp or BatchReduceMatmulOp");
3635  // Check the result dims are valid.
3636  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3637  return batchVariantMatmulOp->emitOpError()
3638  << "Unexpected result dim expression (outside the set of default "
3639  "result dims).";
3640 
3641  // Check for valid number of result dims of input maps.
3642  if (opIndexingMap.getNumResults() > 3)
3643  return batchVariantMatmulOp->emitOpError()
3644  << "no. of result dim expressions exceeds 3.";
3645 
3646  auto hasValidBatchDim = [](AffineMap map) {
3647  AffineExpr batchDim = map.getResult(0);
3648  return batchDim.isFunctionOfDim(0);
3649  };
3650 
3651  // Check if the requested broadcast is valid.
3652  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3653  if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3654  return batchVariantMatmulOp->emitOpError()
3655  << "Invalid broadcast requested.";
3656  } else if (!hasValidBatchDim(opIndexingMap)) {
3657  return batchVariantMatmulOp->emitOpError()
3658  << "Invalid batch dimension expression.";
3659  }
3660  return success();
3661 }
3662 
3663 /// This function checks if the given AffineMap for the output of a
3664 /// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3665 /// dimensions and if the output map result dimensions are valid.
3666 template <typename OpTy>
3667 static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3668  AffineMap opIndexingMap) {
3669  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3670  isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3671  "Expected BatchMatmulOp or BatchReduceMatmulOp");
3672  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3673  opIndexingMap.getNumResults() != 3) {
3674 
3675  return batchVariantMatmulOp->emitOpError()
3676  << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3677  << ").";
3678  }
3679  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3680  opIndexingMap.getNumResults() != 2) {
3681  return batchVariantMatmulOp->emitOpError()
3682  << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3683  << ").";
3684  }
3685 
3686  auto areValidOutputResultDim = [&](AffineMap outputMap) {
3687  return isa<BatchMatmulOp>(batchVariantMatmulOp)
3688  ? outputMap.getResult(0).isFunctionOfDim(0) &&
3689  outputMap.getResult(1).isFunctionOfDim(1) &&
3690  outputMap.getResult(2).isFunctionOfDim(2)
3691  : outputMap.getResult(0).isFunctionOfDim(1) &&
3692  outputMap.getResult(1).isFunctionOfDim(2);
3693  };
3694 
3695  if (!areValidOutputResultDim(opIndexingMap)) {
3696  return batchVariantMatmulOp->emitOpError()
3697  << "Invalid output map result dimension.";
3698  }
3699 
3700  return success();
3701 }
3702 
3703 /// Verifies the broadcast and transpose semantics specified by the explicit
3704 /// indexing map for the BatchMatmulOp/BatchReduceMatmulOp for the operand
3705 /// specified by \p opIndex.
3706 template <typename OpTy>
3707 static LogicalResult
3708 verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3709  unsigned opIndex) {
3710  SmallVector<AffineMap, 3> opIndexingMaps =
3711  batchVariantMatmulOp.getIndexingMapsArray();
3712  SmallVector<AffineMap, 3> defaultIndexingMaps =
3713  batchVariantMatmulOp.getDefaultIndexingMaps(
3714  batchVariantMatmulOp->getContext());
3715 
3716  if (opIndexingMaps.size() != 3)
3717  return batchVariantMatmulOp->emitOpError()
3718  << "Indexing_map attribute must have 3 affine maps.";
3719 
3720  auto opIndexingMap = opIndexingMaps[opIndex];
3721  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3722 
3723  if (opIndex == 2 &&
3724  failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3725  return failure();
3726 
3727  if (opIndex != 2 &&
3728  failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3729  defaultIndexingMap, opIndex == 0)))
3730  return failure();
3731 
3732  return success();
3733 }
3734 
3735 namespace mlir {
3736 namespace linalg {
3737 
3738 std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3739  if (m == 2 && r == 3)
3740  return WinogradConv2DFmr::F_2_3;
3741  if (m == 4 && r == 3)
3742  return WinogradConv2DFmr::F_4_3;
3743  if (m == 2 && r == 5)
3744  return WinogradConv2DFmr::F_2_5;
3745  return std::nullopt;
3746 }
3747 
3748 std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3749  switch (fmr) {
3750  case WinogradConv2DFmr::F_2_3:
3751  return {2, 3};
3752  case WinogradConv2DFmr::F_4_3:
3753  return {4, 3};
3754  case WinogradConv2DFmr::F_2_5:
3755  return {2, 5};
3756  }
3757 }
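// Round-trip sketch between an (m, r) pair and the enum:
//
//   if (std::optional<WinogradConv2DFmr> fmr = getWinogradConv2DFmr(4, 3)) {
//     auto [m, r] = getFmrFromWinogradConv2DFmr(*fmr);
//     assert(m == 4 && r == 3);
//   }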
3758 
3759 //===----------------------------------------------------------------------===//
3760 // MatMulOp
3761 //===----------------------------------------------------------------------===//
3762 
3763 static FailureOr<SmallVector<SmallVector<int64_t>>>
3764 getAffineResultPositions(ArrayAttr maps) {
3765  SmallVector<SmallVector<int64_t>> positions;
3766  for (auto map : maps) {
3767  AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
3768  if (!attr)
3769  return failure();
3770  SmallVector<int64_t> pos;
3771  for (auto result : attr.getAffineMap().getResults()) {
3772  auto dim = dyn_cast<AffineDimExpr>(result);
3773  if (!dim)
3774  return failure();
3775  pos.push_back(dim.getPosition());
3776  }
3777  positions.push_back(pos);
3778  }
3779  return positions;
3780 }
3781 
3782 /// Returns a list of AffineMap with the typical matmul indexing characteristics.
3783 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3784  AffineExpr d0, d1, d2;
3785  SmallVector<AffineMap> indexingMaps;
3786  bindDims(context, d0, d1, d2);
3787  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3788  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3789  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3790  return indexingMaps;
3791 }
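// These encode the canonical matmul accesses, i.e.
// C(d0, d1) += A(d0, d2) * B(d2, d1), with d0 == M, d1 == N, and d2 == K
// (the contracted dimension).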
3792 
3793 bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
3794  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3795  if (!maps)
3796  return false;
3797  if (maps.size() != 3)
3798  return false;
3799  auto positions = getAffineResultPositions(maps);
3800  if (failed(positions))
3801  return false;
3802  return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
3803  (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3804  (*positions)[2] == SmallVector<int64_t>{0, 1};
3805 }
3806 
3807 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3808  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3809  utils::IteratorType::parallel,
3810  utils::IteratorType::reduction};
3811 }
3812 
3813 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3814 
3815 std::string MatmulOp::getLibraryCallName() {
3816  return generateLibraryCallName(getOperation());
3817 }
3818 
3819 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3820 
3821 /// Check if the op has broadcast and/or transpose semantics. Returns true if
3822 /// the user-defined indexing maps are not equal to the default maps.
3823 bool MatmulOp::hasUserDefinedMaps() {
3824  SmallVector<AffineMap, 3> defaultMaps =
3825  getDefaultIndexingMaps(this->getContext());
3826  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3827  return defaultMaps != explicitMaps;
3828 }
3829 
3830 /// Implements the block region builder for the MatmulOp. This is called by
3831 /// 'fillStructuredOpRegion'.
3832 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3833  ArrayRef<NamedAttribute> attrs,
3834  function_ref<InFlightDiagnostic()> emitError) {
3835  if (emitError && block.getNumArguments() != 3) {
3836  emitError() << "MatmulOp regionBuilder expects 3 args, got "
3837  << block.getNumArguments();
3838  return;
3839  }
3840  assert(block.getNumArguments() == 3 &&
3841  "MatmulOp regionBuilder expects 3 args");
3842  RegionBuilderHelper helper(b, block);
3843  SmallVector<Value> yields;
3844 
3845  TypeFn castVal = TypeFn::cast_signed;
3846  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3847  return attr.getName() == "cast";
3848  });
3849  if (castIter != attrs.end()) {
3850  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3851  castVal = attr.getValue();
3852  }
3853 
3854  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3855  block.getArgument(0));
3856  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3857  block.getArgument(1));
3858  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3859  if (!value3)
3860  return;
3861  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3862  value3, emitError);
3863  if (!value4)
3864  return;
3865  yields.push_back(value4);
3866  helper.yieldOutputs(yields);
3867 }
3868 
3869 /// Returns true if the given bcastMap is a valid broadcast map. A valid
3870 /// broadcast map must include the K dimension.
3871 /// TODO: Strict inclusion of the K dimension in the broadcast map is not
3872 /// necessary for both input matrices simultaneously. We can relax this
3873 /// condition to require the K dimension in only one input matrix map and
3874 /// infer the K dimension for the other input matrix map from the one that
3875 /// already has it.
3876 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3877  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3878  AffineExpr expr = bcastMap.getResult(0);
3879  // The map is invalid if the common (K) dimension of the matmul is not found.
3880  return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3881 }
3882 
3883 static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3884  if (parser.parseOptionalKeyword("indexing_maps"))
3885  return ArrayAttr{
3886  nullptr}; // Success in case indexing_maps was not provided.
3887 
3888  ArrayAttr arrayAttr;
3889  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3890  return failure();
3891 
3892  if (llvm::any_of(arrayAttr,
3893  [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3894  return parser.emitError(parser.getCurrentLocation())
3895  << "element of indexing_maps array is not an affine_map";
3896 
3897  return arrayAttr;
3898 }
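// The parsed clause corresponds to assembly of the following shape (types
// elided for brevity):
//
//   linalg.matmul indexing_maps = [
//       affine_map<(d0, d1, d2) -> (d2, d0)>,  // transposed LHS access
//       affine_map<(d0, d1, d2) -> (d2, d1)>,
//       affine_map<(d0, d1, d2) -> (d0, d1)>
//     ] ins(%lhs, %rhs : ...) outs(%acc : ...)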
3899 
3900 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3901  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3902  if (failed(indexingMapsAttr))
3903  return failure();
3904 
3905  if (*indexingMapsAttr == nullptr) {
3906  auto indexingMapAttrs = llvm::map_to_vector(
3907  MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3908  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3909  indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3910  }
3911 
3912  result.addAttribute("indexing_maps", *indexingMapsAttr);
3913  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3914  MatmulOp::getRegionBuilder());
3915 }
3916 
3917 void MatmulOp::print(OpAsmPrinter &p) {
3918  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3919  MatmulOp::getDefaultIndexingMaps(getContext()),
3920  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3921  if (!llvm::equal(getIndexingMaps(), indexingMaps))
3922  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3923 
3924  std::array<StringRef, 3> elidedAttrs = {
3925  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3926  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3927  elidedAttrs);
3928 }
3929 
3930 /// Verify the user-defined indexing maps.
3931 LogicalResult MatmulOp::verify() {
3932  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3933  if (!hasUserDefinedMaps())
3934  return success();
3935 
3936  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3937  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3938  return failure();
3939  }
3940  return success();
3941 }
3942 
3943 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3944  return memref::foldMemRefCast(*this);
3945 }
3946 
3947 void MatmulOp::getEffects(
3948  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3949  &effects) {
3950  if (hasPureTensorSemantics())
3951  return;
3952  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3953 }
3954 
3955 Speculation::Speculatability MatmulOp::getSpeculatability() {
3956  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3957 }
3958 
3959 SmallVector<AffineMap>
3960 MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
3961  AffineExpr d0, d1, d2;
3962  MLIRContext *context = builder.getContext();
3963  bindDims(context, d0, d1, d2);
3964  AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
3965  AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
3966  AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
3967  return {mapLHS, mapRHS, mapOut};
3968 }
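// I.e. C(d0, d1) += A(d2, d0) * B(d2, d1): the LHS operand is accessed
// transposed relative to the default matmul maps.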
3969 
3970 bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
3971  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
3972  if (!maps)
3973  return false;
3974  if (maps.size() != 3)
3975  return false;
3976  auto positions = getAffineResultPositions(maps);
3977  if (failed(positions))
3978  return false;
3979  return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
3980  (*positions)[1] == SmallVector<int64_t>{2, 1} &&
3981  (*positions)[2] == SmallVector<int64_t>{0, 1};
3982 }
3983 
3984 void MatmulTransposeAOp::build(OpBuilder &builder,
3985  OperationState &result,
3986  ValueRange inputs, ValueRange outputs,
3987  ArrayRef<NamedAttribute> attributes) {
3988  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
3989  MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
3990 }
3991 
3992 MatmulTransposeAOp
3993 MatmulTransposeAOp::create(OpBuilder &builder, Location location,
3994  ValueRange inputs, ValueRange outputs,
3995  ArrayRef<NamedAttribute> attributes) {
3996  OperationState state(location, getOperationName());
3997  build(builder, state, inputs, outputs, attributes);
3998  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
3999  assert(res && "builder didn't return the right type");
4000  return res;
4001 }
4002 
4003 void MatmulTransposeAOp::build(OpBuilder &builder,
4004  OperationState &result,
4005  TypeRange resultTensorTypes,
4006  ValueRange inputs, ValueRange outputs,
4007  ArrayRef<NamedAttribute> attributes) {
4008  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4009  MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4010 }
4011 
4012 MatmulTransposeAOp
4013 MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4014  TypeRange resultTensorTypes, ValueRange inputs,
4015  ValueRange outputs,
4016  ArrayRef<NamedAttribute> attributes) {
4017  OperationState state(location, getOperationName());
4018  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4019  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4020  assert(res && "builder didn't return the right type");
4021  return res;
4022 }
4023 
4024 void MatmulTransposeAOp::build(OpBuilder &builder,
4025  OperationState &result,
4026  TypeRange resultTensorTypes,
4027  ValueRange inputs, ValueRange outputs,
4028  Attribute cast,
4029  ArrayRef<NamedAttribute> attributes) {
4030  result.addAttribute("cast", cast);
4031  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4032  MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4033 }
4034 
4035 MatmulTransposeAOp
4036 MatmulTransposeAOp::create(OpBuilder &builder, Location location,
4037  TypeRange resultTensorTypes, ValueRange inputs,
4038  ValueRange outputs, Attribute cast,
4039  ArrayRef<NamedAttribute> attributes) {
4040  OperationState state(location, getOperationName());
4041  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4042  auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
4043  assert(res && "builder didn't return the right type");
4044  return res;
4045 }
4046 
4047 bool MatmulTransposeAOp::classof(Operation *op) {
4048  return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4049  MatmulTransposeAOp::isDefaultIndexingMaps(
4050  op->getAttr("indexing_maps"));
4051 }
4052 
4053 SmallVector<AffineMap>
4054 MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4055  AffineExpr d0, d1, d2;
4056  MLIRContext *context = builder.getContext();
4057  bindDims(context, d0, d1, d2);
4058  AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
4059  AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
4060  AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
4061  return {mapLHS, mapRHS, mapOut};
4062 }
4063 
4064 bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4065  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4066  if (!maps)
4067  return false;
4068  if (maps.size() != 3)
4069  return false;
4070  auto positions = getAffineResultPositions(maps);
4071  if (failed(positions))
4072  return false;
4073  return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
4074  (*positions)[1] == SmallVector<int64_t>{1, 2} &&
4075  (*positions)[2] == SmallVector<int64_t>{0, 1};
4076 }
4077 
4078 void MatmulTransposeBOp::build(OpBuilder &builder,
4079  OperationState &result,
4080  ValueRange inputs, ValueRange outputs,
4081  ArrayRef<NamedAttribute> attributes) {
4082  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4083  MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4084 }
4085 
4086 MatmulTransposeBOp
4087 MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4088  ValueRange inputs, ValueRange outputs,
4089  ArrayRef<NamedAttribute> attributes) {
4090  OperationState state(location, getOperationName());
4091  build(builder, state, inputs, outputs, attributes);
4092  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4093  assert(res && "builder didn't return the right type");
4094  return res;
4095 }
4096 
4097 void MatmulTransposeBOp::build(OpBuilder &builder,
4098  OperationState &result,
4099  TypeRange resultTensorTypes,
4100  ValueRange inputs, ValueRange outputs,
4101  ArrayRef<NamedAttribute> attributes) {
4102  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4103  MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4104 }
4105 
4106 MatmulTransposeBOp
4107 MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4108  TypeRange resultTensorTypes, ValueRange inputs,
4109  ValueRange outputs,
4110  ArrayRef<NamedAttribute> attributes) {
4111  OperationState state(location, getOperationName());
4112  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4113  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4114  assert(res && "builder didn't return the right type");
4115  return res;
4116 }
4117 
4118 void MatmulTransposeBOp::build(OpBuilder &builder,
4119  OperationState &result,
4120  TypeRange resultTensorTypes,
4121  ValueRange inputs, ValueRange outputs,
4122  Attribute cast,
4123  ArrayRef<NamedAttribute> attributes) {
4124  result.addAttribute("cast", cast);
4125  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4126  MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
4127 }
4128 
4129 MatmulTransposeBOp
4130 MatmulTransposeBOp::create(OpBuilder &builder, Location location,
4131  TypeRange resultTensorTypes, ValueRange inputs,
4132  ValueRange outputs, Attribute cast,
4133  ArrayRef<NamedAttribute> attributes) {
4134  OperationState state(location, getOperationName());
4135  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4136  auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
4137  assert(res && "builder didn't return the right type");
4138  return res;
4139 }
4140 
4141 bool MatmulTransposeBOp::classof(Operation *op) {
4142  return dyn_cast_or_null<linalg::MatmulOp>(op) &&
4143  MatmulTransposeBOp::isDefaultIndexingMaps(
4144  op->getAttr("indexing_maps"));
4145 }
4146 
4147 SmallVector<AffineMap>
4148 BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
4149  AffineExpr d0, d1, d2, d3;
4150  MLIRContext *context = builder.getContext();
4151  bindDims(context, d0, d1, d2, d3);
4152  AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
4153  AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
4154  AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4155  return {mapLHS, mapRHS, mapOut};
4156 }
4157 
4158 bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
4159  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4160  if (!maps)
4161  return false;
4162  if (maps.size() != 3)
4163  return false;
4164  auto positions = getAffineResultPositions(maps);
4165  if (failed(positions))
4166  return false;
4167  return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
4168  (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4169  (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4170 }
4171 
4172 void BatchMatmulTransposeAOp::build(
4173  OpBuilder &builder, OperationState &result, ValueRange inputs,
4174  ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4175  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4176  BatchMatmulOp::getRegionBuilder(),
4177  getDefaultIndexingMaps(builder));
4178 }
4179 
4180 BatchMatmulTransposeAOp
4181 BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4182  ValueRange inputs, ValueRange outputs,
4183  ArrayRef<NamedAttribute> attributes) {
4184  OperationState state(location, getOperationName());
4185  build(builder, state, inputs, outputs, attributes);
4186  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4187  assert(res && "builder didn't return the right type");
4188  return res;
4189 }
4190 
4191 void BatchMatmulTransposeAOp::build(
4192  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4193  ValueRange inputs, ValueRange outputs,
4194  ArrayRef<NamedAttribute> attributes) {
4195  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4196  BatchMatmulOp::getRegionBuilder(),
4197  getDefaultIndexingMaps(builder));
4198 }
4199 
4200 BatchMatmulTransposeAOp
4201 BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4202  TypeRange resultTensorTypes, ValueRange inputs,
4203  ValueRange outputs,
4204  ArrayRef<NamedAttribute> attributes) {
4205  OperationState state(location, getOperationName());
4206  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4207  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4208  assert(res && "builder didn't return the right type");
4209  return res;
4210 }
4211 
4212 void BatchMatmulTransposeAOp::build(
4213  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4214  ValueRange inputs, ValueRange outputs, Attribute cast,
4215  ArrayRef<NamedAttribute> attributes) {
4216  result.addAttribute("cast", cast);
4217  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4218  BatchMatmulOp::getRegionBuilder(),
4219  getDefaultIndexingMaps(builder));
4220 }
4221 
4222 BatchMatmulTransposeAOp
4223 BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
4224  TypeRange resultTensorTypes, ValueRange inputs,
4225  ValueRange outputs, Attribute cast,
4226  ArrayRef<NamedAttribute> attributes) {
4227  OperationState state(location, getOperationName());
4228  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4229  auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
4230  assert(res && "builder didn't return the right type");
4231  return res;
4232 }
4233 
4234 bool BatchMatmulTransposeAOp::classof(Operation *op) {
4235  return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4236  BatchMatmulTransposeAOp::isDefaultIndexingMaps(
4237  op->getAttr("indexing_maps"));
4238 }
4239 
4240 SmallVector<AffineMap>
4241 BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
4242  AffineExpr d0, d1, d2, d3;
4243  MLIRContext *context = builder.getContext();
4244  bindDims(context, d0, d1, d2, d3);
4245  AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
4246  AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
4247  AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
4248  return {mapLHS, mapRHS, mapOut};
4249 }
4250 
4251 bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
4252  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4253  if (!maps)
4254  return false;
4255  if (maps.size() != 3)
4256  return false;
4257  auto positions = getAffineResultPositions(maps);
4258  if (failed(positions))
4259  return false;
4260  return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4261  (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
4262  (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4263 }
4264 
4265 void BatchMatmulTransposeBOp::build(
4266  OpBuilder &builder, OperationState &result, ValueRange inputs,
4267  ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
4268  buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
4269  BatchMatmulOp::getRegionBuilder(),
4270  getDefaultIndexingMaps(builder));
4271 }
4272 
4273 BatchMatmulTransposeBOp
4274 BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4275  ValueRange inputs, ValueRange outputs,
4276  ArrayRef<NamedAttribute> attributes) {
4277  OperationState state(location, getOperationName());
4278  build(builder, state, inputs, outputs, attributes);
4279  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4280  assert(res && "builder didn't return the right type");
4281  return res;
4282 }
4283 
4284 void BatchMatmulTransposeBOp::build(
4285  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4286  ValueRange inputs, ValueRange outputs,
4287  ArrayRef<NamedAttribute> attributes) {
4288  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4289  BatchMatmulOp::getRegionBuilder(),
4290  getDefaultIndexingMaps(builder));
4291 }
4292 
4293 BatchMatmulTransposeBOp
4294 BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4295  TypeRange resultTensorTypes, ValueRange inputs,
4296  ValueRange outputs,
4297  ArrayRef<NamedAttribute> attributes) {
4298  OperationState state(location, getOperationName());
4299  build(builder, state, resultTensorTypes, inputs, outputs, attributes);
4300  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4301  assert(res && "builder didn't return the right type");
4302  return res;
4303 }
4304 
4305 void BatchMatmulTransposeBOp::build(
4306  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
4307  ValueRange inputs, ValueRange outputs, Attribute cast,
4308  ArrayRef<NamedAttribute> attributes) {
4309  result.addAttribute("cast", cast);
4310  buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
4311  BatchMatmulOp::getRegionBuilder(),
4312  getDefaultIndexingMaps(builder));
4313 }
4314 
4315 BatchMatmulTransposeBOp
4316 BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
4317  TypeRange resultTensorTypes, ValueRange inputs,
4318  ValueRange outputs, Attribute cast,
4319  ArrayRef<NamedAttribute> attributes) {
4320  OperationState state(location, getOperationName());
4321  build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
4322  auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
4323  assert(res && "builder didn't return the right type");
4324  return res;
4325 }
4326 
4327 bool BatchMatmulTransposeBOp::classof(Operation *op) {
4328  return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
4329  BatchMatmulTransposeBOp::isDefaultIndexingMaps(
4330  op->getAttr("indexing_maps"));
4331 }
4332 
4333 //===----------------------------------------------------------------------===//
4334 // ContractOp
4335 //===----------------------------------------------------------------------===//
4336 
4337 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
4338  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
4339  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
4340  // domains are all the same, and each implements a projected permutation.
4341  // Each iteration space dim must occur for at least one operand and either
4342  // takes part in a contraction/reduction or else has parallel iteration type.
4343  // We have that a dim is a contraction/reduction dim if and only if the dim
4344  // occurs for the output operand. We use this fact for fast inference:
4345  // NB: In case we allow dims to occur solely for one input, the above still
4346  // holds: per the einsum semantics, these are reduction dims as well.
4347  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
4348  for (auto result : outAffineMap.getResults()) {
4349  auto dimExpr = dyn_cast<AffineDimExpr>(result);
4350  assert(dimExpr && "affine_map is a projected permutation");
4351  dimsInOutput[dimExpr.getPosition()] = true;
4352  }
4353 
4354  SmallVector<utils::IteratorType> iteratorTypes;
4355  for (auto dimOccursInOutput : dimsInOutput)
4356  iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
4357  : utils::IteratorType::reduction);
4358 
4359  return iteratorTypes;
4360 }
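// For example, for the plain matmul maps [(d0, d2), (d2, d1), (d0, d1)], d0
// and d1 occur in the output and are parallel, while d2 does not occur in the
// output and is therefore a reduction/contraction dimension.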
4361 
4362 unsigned ContractOp::getNumRegionArgs() { return 3; }
4363 
4364 /// Implement block region builder, which is called by 'fillStructuredOpRegion'.
4365 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4366  ArrayRef<NamedAttribute> attrs,
4367  function_ref<InFlightDiagnostic()> emitError) {
4368  if (emitError && block.getNumArguments() != 3) {
4369  emitError() << "ContractOp regionBuilder expects 3 args, got "
4370  << block.getNumArguments();
4371  return;
4372  }
4373  assert(block.getNumArguments() == 3 &&
4374  "ContractOp regionBuilder expects 3 args");
4375  RegionBuilderHelper helper(b, block);
4376 
4377  TypeFn castSignedness = TypeFn::cast_signed;
4378  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4379  return attr.getName() == "cast";
4380  });
4381  if (castIter != attrs.end()) {
4382  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4383  castSignedness = attr.getValue();
4384  }
4385 
4386  // TODO: Support fields with operators besides mult & add.
4387  Type outType = block.getArgument(2).getType();
4388  Value lhsAtOutType =
4389  helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
4390  Value rhsAtOutType =
4391  helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
4392  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
4393  rhsAtOutType, emitError);
4394  if (!productAtOutType)
4395  return;
4396  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
4397  productAtOutType, emitError);
4398  if (!result)
4399  return;
4400  helper.yieldOutputs({result});
4401 }
4402 
4403 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
4404  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
4405  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
4406  return parser.emitError(parser.getCurrentLocation(),
4407  "expected 'indexing_maps' attribute");
4408  result.addAttribute("indexing_maps", *indexingMapsAttr);
4409 
4410  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
4411  regionBuilder);
4412 }
4413 
4414 void ContractOp::print(OpAsmPrinter &p) {
4415  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4416  printNamedStructuredOp(
4417  p, getOperation(), getInputs(), getOutputs(),
4418  /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
4419 }
4420 
4421 LogicalResult ContractOp::verify() {
4422  int iterationSpaceDims = -1;
4423  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
4424  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
4425  // access an input operand (so occurrence count can be at most 2) and
4426  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
4427  SmallVector<size_t> inOccurrences;
4428  SmallVector<size_t> outOccurrences;
4429 
4430  // A helper so that for each operand's affine_map and type we check that ...
4431  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
4432  bool isInput) -> LogicalResult {
4433  // ... the affine_map is a projected permutation;
4434  if (!affineMap.isProjectedPermutation())
4435  return emitError("provided affine_map is not a projected permutation");
4436 
4437  // ... the rank of the affine_map's results and corresponding type match;
4438  if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
4439  if (affineMap.getNumResults() != shapedType.getRank())
4440  return emitError("ranks of shaped operand and results of corresponding "
4441  "affine_map differ");
4442  } else if (affineMap.getNumResults() != 0) {
4443  return emitError("affine_map specifies shaped access while operand has "
4444  "non-shaped type");
4445  }
4446 
4447  // ... the rank of the affine_map's domain is the same as those seen prior;
4448  if (iterationSpaceDims == -1) {
4449  iterationSpaceDims = affineMap.getNumDims();
4450  inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4451  outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4452  } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4453  return emitError("iteration spaces of provided affine_maps differ");
4454  }
4455 
4456  // ... update counts of dims used to access either an input or the output.
4457  for (AffineExpr affineExpr : affineMap.getResults()) {
4458  auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4459  if (!affineDimExpr)
4460  llvm_unreachable("affine_map is a projected permutation");
4461 
4462  if (isInput)
4463  inOccurrences[affineDimExpr.getPosition()] += 1;
4464  else
4465  outOccurrences[affineDimExpr.getPosition()] += 1;
4466  }
4467 
4468  return success();
4469  };
4470 
4471  for (auto &&[affineMap, operandType, isInput] :
4472  llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4473  SmallVector<bool>{true, true, false})) {
4474  if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4475  return failure(); // NB: checkAffineMapAndType will emit relevant error.
4476  }
4477 
4478  bool hasContractingDim = false;
4479  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4480  size_t inOccCount = inOccurrences[dimIndex];
4481  size_t outOccCount = outOccurrences[dimIndex];
4482 
4483  // We have a contracting dim if and only if ...
4484  hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4485 
4486  if (inOccCount == 0 && outOccCount == 0)
4487  return emitError() << "iteration space dim at index " << dimIndex
4488  << " not used to access any operand";
4489 
4490  // NB: We disallow a dim which occurs for only one input operand and not
4491  // for the output. In terms of einsum semantics such dims have a
4492  // sensible meaning - namely an additional reduction per each such dim.
4493  // By contrast, the ContractionOpInterface does not know about this
4494  // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4495  // while vector.contract's verifier accepts dims of this kind, many of
4496  // its lowerings give up on encountering these dims.
4497  // TODO: Remove following once we have comprehensive support for input-only
4498  // reduction dims, at both the linalg- and vector-dialect levels.
4499  if (inOccCount == 1 && outOccCount != 1)
4500  return emitError()
4501  << "iteration space dim at index " << dimIndex
4502  << " is neither a contracting dim nor of parallel iteration type";
4503  }
4504 
4505  if (!hasContractingDim)
4506  return emitError("'indexing_maps' do not specify a contracting dimension");
4507 
4508  return success();
4509 }
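
// Illustration of the final check (hypothetical example): in the map set
// below every iteration space dim occurs in the output, so no dim is a
// contracting dim and verification fails with "'indexing_maps' do not
// specify a contracting dimension".
//
// ```mlir
// %bad = linalg.contract
//     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
//                      affine_map<(d0, d1) -> (d0, d1)>,
//                      affine_map<(d0, d1) -> (d0, d1)>]
//     ins(%A, %B : tensor<4x4xf32>, tensor<4x4xf32>)
//     outs(%C : tensor<4x4xf32>) -> tensor<4x4xf32>
// ```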
4510 
4511 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4512  return memref::foldMemRefCast(*this);
4513 }
4514 
4515 void ContractOp::getEffects(
4516  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4517  &effects) {
4518  if (hasPureTensorSemantics())
4519  return;
4520  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4521 }
4522 
4523 Speculation::Speculatability ContractOp::getSpeculatability() {
4524  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4525 }
4526 
4527 //===----------------------------------------------------------------------===//
4528 // Implementation of BatchMatmulOp
4529 //===----------------------------------------------------------------------===//
4530 SmallVector<AffineMap>
4531 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4532  AffineExpr d0, d1, d2, d3;
4533  SmallVector<AffineMap> indexingMaps;
4534  bindDims(context, d0, d1, d2, d3);
4535  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4536  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4537  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4538  return indexingMaps;
4539 }
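
// Written out in affine-map notation, with d0 = batch, d1 = M, d2 = N, and
// d3 = K, the three default maps are:
//
// ```mlir
// affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>  // LHS
// affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>  // RHS
// affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>  // init/output
// ```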
4540 
4541 bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
4542  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
4543  if (!maps)
4544  return false;
4545  if (maps.size() != 3)
4546  return false;
4547  auto positions = getAffineResultPositions(maps);
4548  if (failed(positions))
4549  return false;
4550  return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
4551  (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
4552  (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
4553 }
4554 
4555 SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4556  return SmallVector<utils::IteratorType>{
4557  utils::IteratorType::parallel, utils::IteratorType::parallel,
4558  utils::IteratorType::parallel, utils::IteratorType::reduction};
4559 }
4560 
4561 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4562 
4563 std::string BatchMatmulOp::getLibraryCallName() {
4564  return generateLibraryCallName(getOperation());
4565 }
4566 
4567 /// Check if the op has broadcast and/or transpose semantics. Returns true if
4568 /// the user-defined indexing maps are not equal to the default maps.
4569 bool BatchMatmulOp::hasUserDefinedMaps() {
4570  SmallVector<AffineMap, 3> defaultMaps =
4571  getDefaultIndexingMaps(this->getContext());
4572  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4573  return defaultMaps != explicitMaps;
4574 }
4575 
4576 /// Returns true if the given map `bcastMap` is a valid broadcast map. A valid
4577 /// broadcast map must include the K dimension.
4578 /// TODO: Strict inclusion of the K dimension in the broadcast map is not
4579 /// necessary for both input matrices simultaneously. We can relax this
4580 /// condition to require the K dimension in only one input matrix map and
4581 /// infer the K dimension for the other input matrix map from the one that
4582 /// already has it.
4583 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4584  assert(bcastMap.getNumResults() < 3 &&
4585  "Expected less than 3 result dim expr.");
4586  bool isValid = false;
4587  enum Indices { batchPos, mPos, nPos, kPos };
4588  if (bcastMap.getNumResults() == 1) {
4589  AffineExpr expr = bcastMap.getResult(0);
4590  isValid = expr.isFunctionOfDim(kPos);
4591  } else if (bcastMap.getNumResults() == 2) {
4592  AffineExpr expr0 = bcastMap.getResult(0);
4593  AffineExpr expr1 = bcastMap.getResult(1);
4594  isValid =
4595  isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4596  expr0.isFunctionOfDim(mPos)) &&
4597  expr1.isFunctionOfDim(kPos))
4598  : ((expr0.isFunctionOfDim(batchPos) &&
4599  expr1.isFunctionOfDim(kPos)) ||
4600  (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4601  }
4602  return isValid;
4603 }
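
// Illustration (hypothetical maps over the domain (batch, m, n, k)):
//
// ```mlir
// affine_map<(batch, m, n, k) -> (k)>        // valid: K-only broadcast
// affine_map<(batch, m, n, k) -> (m, k)>     // valid LHS broadcast
// affine_map<(batch, m, n, k) -> (batch, m)> // invalid: K dimension missing
// ```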
4604 
4605 void BatchMatmulOp::regionBuilder(
4606  ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4607  function_ref<InFlightDiagnostic()> emitError) {
4608  if (emitError && block.getNumArguments() != 3) {
4609  emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4610  << block.getNumArguments();
4611  return;
4612  }
4613  assert(block.getNumArguments() == 3 &&
4614  "BatchMatmulOp regionBuilder expects 3 args");
4615  RegionBuilderHelper helper(b, block);
4616  SmallVector<Value> yields;
4617 
4618  TypeFn castVal = TypeFn::cast_signed;
4619  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4620  return attr.getName() == "cast";
4621  });
4622  if (castIter != attrs.end()) {
4623  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4624  castVal = attr.getValue();
4625  }
4626 
4627  auto toType = block.getArgument(2).getType();
4628  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4629  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4630  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4631  Value addVal =
4632  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4633  yields.push_back(addVal);
4634  helper.yieldOutputs(yields);
4635 }
4636 
4637 ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4638  SmallVector<Attribute, 3> indexingMapsAttr;
4639  Attribute mapAttr;
4640  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4641  if (parser.parseEqual())
4642  return failure();
4643 
4644  if (parser.parseLSquare())
4645  return failure();
4646 
4647  do {
4648  if (parser.parseAttribute(mapAttr))
4649  return failure();
4650  if (!isa<AffineMapAttr>(mapAttr)) {
4651  return parser.emitError(parser.getCurrentLocation(),
4652  "expected affine map attribute");
4653  }
4654  indexingMapsAttr.push_back(mapAttr);
4655 
4656  if (parser.parseOptionalComma())
4657  break;
4658  } while (true);
4659 
4660  if (parser.parseRSquare())
4661  return failure();
4662  }
4663  // Initialize indexingMaps, if not supplied explicitly.
4664  if (indexingMapsAttr.empty()) {
4665  indexingMapsAttr = llvm::map_to_vector(
4666  BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4667  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4668  }
4669  result.addAttribute("indexing_maps",
4670  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4671 
4672  return ::parseNamedStructuredOp(parser, result,
4673  BatchMatmulOp::getNumRegionArgs(),
4674  BatchMatmulOp::getRegionBuilder());
4675 }
4676 
4677 void BatchMatmulOp::print(OpAsmPrinter &p) {
4678  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4679  BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4680  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4681  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4682  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4683 
4684  std::array<StringRef, 3> elidedAttrs = {
4685  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4686  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4687  elidedAttrs);
4688 }
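
// Illustration (a sketch; shapes are hypothetical): a transpose-B variant
// carries user-defined maps, which the printer keeps because they differ
// from the defaults; with the default maps the indexing_maps clause is
// elided entirely.
//
// ```mlir
// %0 = linalg.batch_matmul
//     indexing_maps = [affine_map<(b, m, n, k) -> (b, m, k)>,
//                      affine_map<(b, m, n, k) -> (b, n, k)>,
//                      affine_map<(b, m, n, k) -> (b, m, n)>]
//     ins(%A, %B : tensor<2x4x8xf32>, tensor<2x16x8xf32>)
//     outs(%C : tensor<2x4x16xf32>) -> tensor<2x4x16xf32>
// ```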
4689 
4690 /// Verify the user-defined indexing maps.
4691 LogicalResult BatchMatmulOp::verify() {
4692  // Verification of pure batch_matmul is handled by
4693  // verifyStructuredOpInterface().
4694  if (!hasUserDefinedMaps())
4695  return success();
4696 
4697  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4698  if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4699  return failure();
4700  }
4701  return success();
4702 }
4703 
4704 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4705  SmallVectorImpl<OpFoldResult> &) {
4706  return memref::foldMemRefCast(*this);
4707 }
4708 
4709 void BatchMatmulOp::getEffects(
4710  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4711  &effects) {
4712  if (hasPureTensorSemantics())
4713  return;
4714  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4715 }
4716 
4717 Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4718  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4719 }
4720 
4721 //===----------------------------------------------------------------------===//
4722 // ElementwiseOp
4723 //===----------------------------------------------------------------------===//
4724 //
4725 namespace {
4726 struct ArityGroupAndKind {
4727  // The enum class {Unary, Binary, Ternary, ..}
4728  ElementwiseArityGroup arityGroup;
4729 
4730  // The kind (e.g. `exp` or `add`) belonging to the arity group.
4731  union Kind {
4732  UnaryFn unaryFn;
4733  BinaryFn binaryFn;
4734  TernaryFn ternaryFn;
4735  } kind;
4736 };
4737 
4738 unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4739  return static_cast<unsigned>(arityGroup);
4740 }
4741 } // namespace
4742 
4743 static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4744  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4745  constexpr int lastBinary =
4746  static_cast<int>(ElementwiseCaseLimits::LastBinary);
4747  constexpr int lastTernary =
4748  static_cast<int>(ElementwiseCaseLimits::LastTernary);
4749 
4750  int val = static_cast<int>(kind);
4751  ArityGroupAndKind result;
4752 
4753  if (val < lastUnary) {
4754  result.arityGroup = ElementwiseArityGroup::Unary;
4755  result.kind.unaryFn = static_cast<UnaryFn>(val);
4756  return result;
4757  }
4758  if (val < lastBinary) {
4759  result.arityGroup = ElementwiseArityGroup::Binary;
4760  result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4761  return result;
4762  }
4763  if (val >= lastTernary) {
4764  llvm_unreachable("unhandled ElementwiseFn");
4765  }
4766  result.arityGroup = ElementwiseArityGroup::Ternary;
4767  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4768  return result;
4769 }
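
// For example (illustrative of the scheme; the exact enum values live in the
// generated ElementwiseKind tables): a kind such as `exp` falls below
// LastUnary and maps to (Unary, UnaryFn::exp), while `add` falls between
// LastUnary and LastBinary and maps to (Binary, BinaryFn::add) after
// rebasing by lastUnary.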
4770 
4771 SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4772  auto rank = getResultRank();
4773  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4774 }
4775 
4776 SmallVector<AffineMap>
4777 ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4778  MLIRContext *context) {
4779  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4780  return SmallVector<AffineMap>(numMaps, map);
4781 }
4782 
4783 ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4784  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4785  Attribute attr;
4786  mlir::linalg::ElementwiseKind elemwiseKindVal;
4787  if (parser.parseKeyword("kind") || parser.parseEqual())
4788  return failure();
4789 
4790  if (succeeded(parser.parseAttribute(attr))) {
4791  auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4792  if (!elemwiseKindAttr)
4793  return parser.emitError(parser.getCurrentLocation(),
4794  "expected ElementwiseKind attribute");
4795  elemwiseKindVal = elemwiseKindAttr.getValue();
4796  } else {
4797  return parser.emitError(parser.getCurrentLocation(),
4798  "expected operation 'kind' attribute");
4799  }
4800  result.addAttribute(
4801  "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4802 
4803  // Parse optional `indexing_maps`
4804  SmallVector<Attribute, 3> indexingMapsAttr;
4805  Attribute mapAttr;
4806  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4807  if (parser.parseEqual())
4808  return failure();
4809  if (parser.parseLSquare())
4810  return failure();
4811  do {
4812  if (parser.parseAttribute(mapAttr))
4813  return failure();
4814  if (!isa<AffineMapAttr>(mapAttr))
4815  return parser.emitError(parser.getCurrentLocation(),
4816  "expected affine map attribute");
4817  indexingMapsAttr.push_back(mapAttr);
4818  if (parser.parseOptionalComma())
4819  break;
4820  } while (true);
4821  if (parser.parseRSquare())
4822  return failure();
4823  }
4824  // At this stage of parsing, the only way to infer the number of region
4825  // args is through the op kind, as input/output tensors are not parsed yet.
4826  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4827  int numRegionArgs =
4828  getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4829  if (parseNamedStructuredOp(parser, result, numRegionArgs,
4830  ElementwiseOp::getRegionBuilder())) {
4831  return parser.emitError(parser.getCurrentLocation(),
4832  "unable to parse elemwise op");
4833  }
4834 
4835  // Initialize indexingMaps, if not supplied explicitly.
4836  if (indexingMapsAttr.empty()) {
4837  // We need to infer the numDims of the indexing maps from the output
4838  // type which is already parsed by now.
4839  auto resultType = result.operands[result.operands.size() - 1].getType();
4840  auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4841  if (!shapedType)
4842  return parser.emitError(parser.getCurrentLocation(),
4843  "return type needs to be shaped type");
4844  auto numDims = shapedType.getRank();
4845  indexingMapsAttr = llvm::map_to_vector(
4846  ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4847  parser.getContext()),
4848  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4849  }
4850 
4851  result.addAttribute("indexing_maps",
4852  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4853  return success();
4854 }
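
// Illustration (a sketch; shapes and names are hypothetical): a unary op in
// the custom form accepted above. No indexing_maps are supplied, so rank-2
// identity maps are attached as defaults.
//
// ```mlir
// %exp = linalg.elementwise kind=#linalg.elementwise_kind<exp>
//     ins(%x : tensor<4x8xf32>) outs(%y : tensor<4x8xf32>) -> tensor<4x8xf32>
// ```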
4855 
4856 void ElementwiseOp::print(OpAsmPrinter &p) {
4857  p << " kind=";
4858  p.printAttribute(getKindAttr());
4859  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4860  "indexing_maps"};
4861  unsigned arity =
4862  getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4863  unsigned numDims = getResultRank();
4864 
4865  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4866  ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4867  getContext()),
4868  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4869 
4870  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4871  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4872 
4873  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4874  elidedAttrs);
4875 }
4876 
4877 LogicalResult ElementwiseOp::verify() {
4878  // All necessary checks are done either by
4879  // - EnumAttr (e.g. unknown operation kind)
4880  // - verifyStructuredOpInterface (incorrect map, sizes).
4881  return success();
4882 }
4883 
4884 /// Implements the block region builder for the ElementwiseOp. This is called by
4885 /// 'fillStructuredOpRegion'.
4886 void ElementwiseOp::regionBuilder(
4887  ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4888  function_ref<InFlightDiagnostic()> emitError) {
4889  ElementwiseKind elemwiseKind;
4890  for (auto attr : attrs) {
4891  if (attr.getName() == b.getStringAttr("kind")) {
4892  auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4893  assert(kindAttr && "op kind attribute incorrectly set");
4894  elemwiseKind = kindAttr.getValue();
4895  break;
4896  }
4897  }
4898 
4899  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4900  auto arityGroup = groupAndKind.arityGroup;
4901  auto kind = groupAndKind.kind;
4902  if (emitError && block.getNumArguments() !=
4903  getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4904  emitError() << "Elementwise regionBuilder expects "
4905  << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4906  << block.getNumArguments();
4907  return;
4908  }
4909  assert(block.getNumArguments() ==
4910  getArityGroupAsUInt(arityGroup) + 1 /*output*/
4911  && "Elementwise regionBuilder number of block args mismatch");
4912 
4913  RegionBuilderHelper helper(b, block);
4914  SmallVector<Value> yields;
4915  Value result;
4916 
4917  if (arityGroup == ElementwiseArityGroup::Unary) {
4918  result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4919 
4920  } else if (arityGroup == ElementwiseArityGroup::Binary) {
4921  result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4922  block.getArgument(1));
4923 
4924  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4925  result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4926  block.getArgument(1), block.getArgument(2));
4927 
4928  } else {
4929  assert(false && "found unhandled category in elemwise");
4930  }
4931 
4932  yields.push_back(result);
4933  helper.yieldOutputs(yields);
4934 }
4935 
4936 LogicalResult ElementwiseOp::fold(FoldAdaptor,
4937  SmallVectorImpl<OpFoldResult> &) {
4938  return memref::foldMemRefCast(*this);
4939 }
4940 
4941 void ElementwiseOp::getEffects(
4942  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4943  &effects) {
4944  if (hasPureTensorSemantics())
4945  return;
4946  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4947 }
4948 
4949 Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4950  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4951 }
4952 
4953 //===----------------------------------------------------------------------===//
4954 // PackOp/UnPackOp Common
4955 //===----------------------------------------------------------------------===//
4956 
4957 template <typename OpTy, typename>
4958 SmallVector<int64_t>
4959 getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
4960  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
4961  ? packOrUnPack.getDestType()
4962  : packOrUnPack.getSourceType();
4963  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4964  ? packOrUnPack.getSourceType()
4965  : packOrUnPack.getDestType();
4966  SmallVector<int64_t> result(
4967  packedType.getShape().take_front(unpackedType.getRank()));
4968  if (!packOrUnPack.getOuterDimsPerm().empty()) {
4969  applyPermutationToVector(
4970  result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
4971  }
4972  return result;
4973 }
4974 template SmallVector<int64_t>
4975  getPackedOuterShapeWithoutTransposition<PackOp>(PackOp);
4976 template SmallVector<int64_t>
4977  getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp);
4978 
4979 // Given the (potentially) updated packed type, `newPackedTy`, generates an
4980 // updated mixed-tile-sizes attribute. A tile size is updated only
4981 // when:
4982 // * a dim from newPackedTy is static, and
4983 // * the corresponding size from mixedTiles is still dynamic.
4984 // Otherwise, the original tile size is preserved.
4985 // Note - packed-type-dim and mixed-tile-size should always match!
4986 static SmallVector<OpFoldResult>
4987 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4988  SmallVector<OpFoldResult> mixedTiles) {
4989  SmallVector<OpFoldResult> newMixedTileSizes;
4990  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4991  .getShape()
4992  .take_back(mixedTiles.size()),
4993  mixedTiles)) {
4994  int64_t shape = std::get<0>(it);
4995  if (shape == ShapedType::kDynamic) {
4996  newMixedTileSizes.push_back(std::get<1>(it));
4997  continue;
4998  }
4999 
5000  // If the current result dim is static, update the dynamic mixed-size
5001  // (provided the original value is dynamic).
5002  OpFoldResult tile = std::get<1>(it);
5003  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
5004  // Already a constant
5005  newMixedTileSizes.push_back(tile);
5006  } else {
5007  assert(getConstantIntValue(tile).value() == shape &&
5008  "tile size and dim size don't match!");
5009  newMixedTileSizes.push_back(
5010  (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
5011  }
5012  }
5013 
5014  return newMixedTileSizes;
5015 }
5016 
5017 template <typename OpTy>
5018 static LogicalResult
5019 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
5020  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5021  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5022  "applies to only pack or unpack operations");
5023  int64_t destRank = op.getDestRank();
5024  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
5025  reifiedReturnShapes[0] =
5026  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
5027  return success();
5028 }
5029 
5030 template <typename OpTy>
5031 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
5032  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5033  "applies to only pack or unpack operations");
5034  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
5035  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
5036  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
5037  assert(tiles.size() == dimsToTile.size() &&
5038  "tiles must match indices of dimension to block");
5039  // Bind the dimension `i` to its tile factor.
5040  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
5041  dimAndTileMapping[dimsToTile[i]] = tiles[i];
5042  return dimAndTileMapping;
5043 }
5044 
5045 template <typename OpTy>
5047  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5048  "applies to only pack or unpack operations");
5049  Builder builder(op);
5050  SmallVector<OpFoldResult> mixedInnerTiles;
5051  unsigned dynamicValIndex = 0;
5052  for (int64_t staticTile : op.getStaticInnerTiles()) {
5053  if (ShapedType::isStatic(staticTile))
5054  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
5055  else
5056  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
5057  }
5058  return mixedInnerTiles;
5059 }
5060 
5061 template <typename OpTy>
5062 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
5063  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5064  "applies to only pack or unpack operations");
5065  SmallVector<Value> dynamicTiles;
5066  SmallVector<int64_t> staticTiles;
5067  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
5068  return staticTiles;
5069 }
5070 
5071 /// Returns true if `dimsPos` is invalid. It is invalid when:
5072 /// a) It contains duplicates.
5073 /// b) At least one dimension is out of bounds (each `dimPos` must be >= 0 and < rank).
5074 /// c) The number of elements in `dimsPos` is greater than `rank`.
5075 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
5076  size_t rank) {
5077  size_t dimsPosSize = dimsPos.size();
5078  if (dimsPosSize > rank)
5079  return true;
5080  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
5081  if (dimsPosSize != uniqued.size())
5082  return true;
5083  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
5084  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
5085  });
5086 }
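
// For example (illustrative), with rank = 3: [0, 1] is a valid position
// specification, while [0, 0] (duplicate), [0, 3] (out of bounds), and
// [0, 1, 2, 0] (more entries than rank) are all invalid.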
5087 
5088 template <typename OpTy>
5089 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
5090  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5091  "applies to only pack or unpack operations");
5092  Operation *op = packOrUnPack.getOperation();
5093 
5094  // Return true if we have a zero-value tile.
5095  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
5096  return llvm::any_of(tiles, isZeroInteger);
5097  };
5098 
5099  // Verify tiles. Do not allow zero tiles.
5100  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
5101  if (hasZeros(mixedTiles))
5102  return op->emitError("invalid zero tile factor");
5103 
5104  // Verify inner_dims_pos and outer_dims_perm.
5105  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
5106  ? packOrUnPack.getSourceType()
5107  : packOrUnPack.getDestType();
5108  size_t unpackedRank = unpackedType.getRank();
5109  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
5110  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
5111  if (isInvalidPackingPosSpecification(innerDimsPos, unpackedRank))
5112  return op->emitError("invalid inner_dims_pos vector");
5113  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
5114  return op->emitError("invalid outer_dims_perm vector");
5115  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
5116  return op->emitError("outer_dims_perm must be a permutation or empty");
5117 
5118  // Tiling factors must be less than or equal to the input rank for pack (or
5119  // output rank for unpack), and must match the number of `inner_dims_pos`.
5120  if (mixedTiles.size() > unpackedRank) {
5121  return op->emitError("tiling factors must be less than or equal to the "
5122  "input rank for pack or output rank for unpack");
5123  }
5124  if (mixedTiles.size() != innerDimsPos.size()) {
5125  return op->emitError(
5126  "tiling factors must equal the number of dimensions to tile");
5127  }
5128 
5129  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5130  ? packOrUnPack.getDestType()
5131  : packOrUnPack.getSourceType();
5132  size_t packedRank = packedType.getRank();
5133  // Require output rank to match input rank + number of blocking factors.
5134  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
5135  if (expectedPackedRank != packedRank) {
5136  return op->emitError(
5137  "packed rank != (unpacked rank + num tiling factors), got ")
5138  << packedRank << " != " << expectedPackedRank;
5139  }
5140 
5141  // Verify result shape is greater than the minimum expected
5142  // by the pack operation, and that the output shape
5143  // represents full tiles.
5144  RankedTensorType expectedPackedType = PackOp::inferPackedType(
5145  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
5146  if (!llvm::all_of(
5147  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
5148  mixedTiles),
5149  [](std::tuple<int64_t, OpFoldResult> it) {
5150  int64_t shape = std::get<0>(it);
5151  if (Attribute attr =
5152  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
5153  IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
5154  int64_t staticTileSize = intAttr.getValue().getSExtValue();
5155  return shape == staticTileSize;
5156  }
5157  return ShapedType::isDynamic(shape);
5158  })) {
5159  return op->emitError("mismatch in inner tile sizes specified and shape of "
5160  "tiled dimension in the packed type");
5161  }
5162  if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
5163  packedType.getShape()))) {
5164  return op->emitError("expected ")
5165  << expectedPackedType << " for the packed domain value, got "
5166  << packedType;
5167  }
5168  return success();
5169 }
5170 
5171 namespace {
5172 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
5173 /// various permutations to the op.
5174 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
5175 // these. These may or may not become true foldings / canonicalizations
5176 // depending on how aggressive we want to be in automatically folding
5177 // transposes.
5178 struct PackOrUnPackTransposeResult {
5179  SmallVector<int64_t> innerDimsPos;
5180  SmallVector<OpFoldResult> innerTiles;
5181  SmallVector<int64_t> outerDimsPerm;
5182 };
5183 } // namespace
5184 
5185 template <typename OpTy>
5186 static PackOrUnPackTransposeResult
5187 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
5188  ArrayRef<int64_t> innerPermutation,
5189  ArrayRef<int64_t> outerPermutation) {
5190  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5191  "applies to only pack or unpack operations");
5192  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
5193  "some permutation must be non-empty");
5194  PackOrUnPackTransposeResult metadata;
5195  metadata.innerDimsPos =
5196  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
5197  metadata.innerTiles =
5198  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
5199  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
5200  ? packOrUnPackOp.getSourceRank()
5201  : packOrUnPackOp.getDestRank();
5202  metadata.outerDimsPerm =
5203  packOrUnPackOp.getOuterDimsPerm().empty()
5204  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
5205  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
5206  if (!innerPermutation.empty()) {
5207  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
5208  isPermutationVector(innerPermutation) &&
5209  "invalid inner permutation");
5210  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
5211  applyPermutationToVector(metadata.innerTiles, innerPermutation);
5212  }
5213  if (!outerPermutation.empty()) {
5214  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
5215  isPermutationVector(outerPermutation) &&
5216  "invalid outer permutation");
5217  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
5218  }
5219  return metadata;
5220 }
5221 
5222 //===----------------------------------------------------------------------===//
5223 // PackOp
5224 //===----------------------------------------------------------------------===//
5225 
5226 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
5227  setNameFn(getResult(), "pack");
5228 }
5229 
5230 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
5231  Value dest, ArrayRef<int64_t> innerDimsPos,
5232  ArrayRef<OpFoldResult> innerTiles,
5233  std::optional<Value> paddingValue,
5234  ArrayRef<int64_t> outerDimsPerm) {
5235  assert(innerDimsPos.size() == innerTiles.size() &&
5236  "number of tile sizes specified must match the specified number of "
5237  "original dimensions to be tiled");
5238  SmallVector<int64_t> staticTileSizes;
5239  SmallVector<Value> dynamicTileSizes;
5240  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5241  build(builder, state, dest.getType(), source, dest,
5242  paddingValue ? *paddingValue : nullptr,
5243  outerDimsPerm.empty() ? nullptr
5244  : builder.getDenseI64ArrayAttr(outerDimsPerm),
5245  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5246  builder.getDenseI64ArrayAttr(staticTileSizes));
5247 }
5248 
5249 LogicalResult
5250 PackOp::reifyResultShapes(OpBuilder &builder,
5251  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5252  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5253 }
5254 
5255 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
5256  return getDimAndTileMappingImpl(*this);
5257 }
5258 
5259 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
5260  return getMixedTilesImpl(*this);
5261 }
5262 
5263 SmallVector<int64_t> PackOp::getStaticTiles() {
5264  return getStaticTilesImpl(*this);
5265 }
5266 
5267 ArrayRef<int64_t> PackOp::getAllOuterDims() {
5268  ShapedType inputType = getSourceType();
5269  int64_t inputRank = inputType.getRank();
5270  return getDestType().getShape().take_front(inputRank);
5271 }
5272 
5273 SmallVector<int64_t> PackOp::getTiledOuterDims() {
5274  auto innerDimsPos = getInnerDimsPos();
5275  auto packedShape = getDestType().getShape();
5276  SmallVector<int64_t> res;
5277 
5278  for (auto index : innerDimsPos)
5279  res.push_back(packedShape[index]);
5280 
5281  return res;
5282 }
5283 
5284 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
5285  ArrayRef<int64_t> innerDimsPos,
5286  ArrayRef<int64_t> outputShape,
5287  ArrayRef<int64_t> outerDimsPerm,
5288  ArrayRef<OpFoldResult> innerTiles) {
5289  SmallVector<int64_t> outputTileSizes(
5290  outputShape.take_front(inputShape.size()));
5291  if (!outerDimsPerm.empty()) {
5292  assert(outerDimsPerm.size() == outputTileSizes.size() &&
5293  "expected output and outer_dims_perm to have same size");
5294  applyPermutationToVector(outputTileSizes,
5295  outerDimsPerm);
5296  }
5297  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
5298  if (ShapedType::isDynamic(inputShape[pos]))
5299  continue;
5300  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
5301 
5302  if (!constantTile) {
5303  if (ShapedType::isStatic(outputTileSizes[pos]) &&
5304  (inputShape[pos] % outputTileSizes[pos] != 0))
5305  return true;
5306  } else if (inputShape[pos] % (*constantTile) != 0) {
5307  return true;
5308  }
5309  }
5310  return false;
5311 }
5312 
5313 LogicalResult PackOp::verify() {
5314  if (failed(commonVerifierPackAndUnPackOp(*this)))
5315  return failure();
5316 
5317  // Verify padding value, and bail out if the tile does not divide the
5318  // dimension fully. In the case of dynamic tile factors or dimensions, having
5319  // a partial tile is undefined behavior.
5320  auto paddingValue = getPaddingValue();
5321  if (paddingValue &&
5322  paddingValue.getType() != getSourceType().getElementType()) {
5323  return emitOpError("expected padding_value has ")
5324  << getSourceType().getElementType()
5325  << " but got: " << paddingValue.getType();
5326  }
5327 
5328  if (!paddingValue &&
5329  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
5330  getDestType().getShape(), getOuterDimsPerm(),
5331  getMixedTiles())) {
5332  return emitOpError(
5333  "invalid tile factor or output size provided. Only full tiles are "
5334  "supported when padding_value is not set");
5335  }
5336  return success();
5337 }
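
// Illustration (hypothetical shapes): tiles [8, 32] divide 256 evenly but
// not 100, so packing tensor<100x256xf32> requires a padding_value; without
// one, verification fails with the "only full tiles are supported" error.
//
// ```mlir
// %pack = linalg.pack %src padding_value(%pad : f32)
//     inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//     into %dest : tensor<100x256xf32> -> tensor<13x8x8x32xf32>
// ```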
5338 
5339 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
5340 /// all Values to kDynamic, even if they are arith.constant values.
5341 static SmallVector<int64_t>
5342 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
5343  SmallVector<int64_t> result;
5344  for (auto o : ofrs) {
5345  // Have to do this first, as getConstantIntValue special-cases constants.
5346  if (llvm::dyn_cast_if_present<Value>(o))
5347  result.push_back(ShapedType::kDynamic);
5348  else
5349  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
5350  }
5351  return result;
5352 }
5353 
5354 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
5355 /// the packed type. Having a shared helper helps implement these two methods in
5356 /// a way that ensures that they agree on which dimensions are dynamic.
5357 static SmallVector<int64_t> getPackOpResultTypeShape(
5358  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
5359  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
5360  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
5361  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5362  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
5363  continue;
5364  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
5365  resultShape[tiledDim.value()] = ShapedType::kDynamic;
5366  continue;
5367  }
5368  resultShape[tiledDim.value()] = llvm::divideCeilSigned(
5369  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
5370  }
5371 
5372  // Swap tile loops if outer_dims_perm is available.
5373  if (!outerDimsPerm.empty())
5374  applyPermutationToVector(resultShape, outerDimsPerm);
5375 
5376  // Append the inner tile dimensions.
5377  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
5378  return resultShape;
5379 }
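
// Worked example (hypothetical sizes): sourceShape = [100, 256],
// innerTileSizes = [8, 32], innerDimsPos = [0, 1], outerDimsPerm = [1, 0].
// The tiled outer dims become ceil(100/8) = 13 and ceil(256/32) = 8, the
// permutation swaps them to [8, 13], and the tile sizes are appended,
// yielding resultShape = [8, 13, 8, 32].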
5380 
5381 SmallVector<OpFoldResult> PackOp::getResultShape(
5382  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
5383  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
5384  ArrayRef<int64_t> outerDimsPerm) {
5385  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
5386 
5387  AffineExpr s0, s1;
5388  bindSymbols(builder.getContext(), s0, s1);
5389  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
5390  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
5391  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
5392  builder, loc, ceilDivExpr,
5393  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
5394  }
5395  if (!outerDimsPerm.empty())
5396  applyPermutationToVector(resultDims, outerDimsPerm);
5397  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
5398 
5399  SmallVector<int64_t> resultTypeShape =
5400  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
5401  asShapeWithAnyValueAsDynamic(innerTileSizes),
5402  innerDimsPos, outerDimsPerm);
5403 
5404  // Fix-up `resultDims` to ensure that they are Value's if and only if the
5405  // result type shape says it's a dynamic dim. This is needed as callers may
5406  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
5407  // dynamic dims returned by that.
5408  for (unsigned i = 0; i < resultDims.size(); ++i) {
5409  if (ShapedType::isStatic(resultTypeShape[i]))
5410  continue;
5411  resultDims[i] =
5412  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
5413  }
5414 
5415  return resultDims;
5416 }
5417 
5418 /// Get the expected packed type based on source type, tile factors, position of
5419 /// the inner tiles and permutation of the outer tiled loop.
5420 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
5421  ArrayRef<int64_t> innerTileSizes,
5422  ArrayRef<int64_t> innerDimsPos,
5423  ArrayRef<int64_t> outerDimsPerm) {
5424  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
5425  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
5426  return RankedTensorType::get(resultShape, sourceType.getElementType());
5427 }
5428 
5429 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
5430  ArrayRef<OpFoldResult> innerTileSizes,
5431  ArrayRef<int64_t> innerDimsPos,
5432  ArrayRef<int64_t> outerDimsPerm) {
5433  AffineExpr dim0, dim1;
5434  bindDims(b.getContext(), dim0, dim1);
5435  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5436  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
5437  {v1, v2});
5438  };
5439 
5440  SmallVector<OpFoldResult> mixedSizes;
5441  for (auto [index, value] : llvm::enumerate(
5442  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
5443  if (ShapedType::isDynamic(value))
5444  mixedSizes.push_back(
5445  tensor::DimOp::create(b, loc, source, index).getResult());
5446  else
5447  mixedSizes.push_back(b.getIndexAttr(value));
5448  }
5449  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5450  int64_t dimPos = std::get<0>(it);
5451  OpFoldResult tileSize = std::get<1>(it);
5452  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5453  }
5454  if (!outerDimsPerm.empty())
5455  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5456 
5457  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5458  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5459  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5460 }
5461 
5462 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5463  ArrayRef<int64_t> innerPermutation,
5464  ArrayRef<int64_t> outerPermutation) {
5465  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5466  *this, innerPermutation, outerPermutation);
5467  Value transposedDest =
5468  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5469  metadata.innerDimsPos, metadata.outerDimsPerm);
5470  return PackOp::create(b, loc, getSource(), transposedDest,
5471  metadata.innerDimsPos, metadata.innerTiles,
5472  getPaddingValue(), metadata.outerDimsPerm);
5473 }
5474 
5475 /// Returns true if the tiles and the tiled dims are constant.
5476 template <typename OpTy>
5477 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
5478  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5479  "applies to only pack or unpack operations");
5480  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5481  ? op.getDestType()
5482  : op.getSourceType();
5483  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5484  for (auto [dimDest, tile] : llvm::zip(
5485  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5486  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5487  if (!constTileSize || ShapedType::isDynamic(dimDest))
5488  return false;
5489  }
5490  return true;
5491 }
5492 
5493 Speculation::Speculatability PackOp::getSpeculatability() {
5494  if (getPaddingValue())
5495  return Speculation::Speculatable;
5496 
5497  // The verifier already rejects operations where we can statically prove
5498  // that the tile sizes do not evenly divide the dimensions; thus, only check
5499  // that the tiles and the tiled inner dimensions are constant.
5500  if (!areTilesAndTiledDimsAllConstant(*this))
5501  return Speculation::NotSpeculatable;
5502 
5503  return Speculation::Speculatable;
5504 }
5505 
5506 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5507 // dimensions for pack and unpack.
5508 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5509  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5510  return false;
5511  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5512  return true;
5513  // Outer dims permutation is optional.
5514  // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5515  // identity permutation.
5516  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5517  isIdentityPermutation(unPackOp.getOuterDimsPerm());
5518 }
5519 
5520 // Return true if pack and unpack have the same tiles.
5521 // Same SSA values or same integer constants.
5522 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5523  auto packTiles = packOp.getMixedTiles();
5524  auto unPackTiles = unPackOp.getMixedTiles();
5525  if (packTiles.size() != unPackTiles.size())
5526  return false;
5527  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5528  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5529  return false;
5530  }
5531  return true;
5532 }
5533 
5534 /// Returns true if the pack op does not need a padding value.
5535 static bool paddingIsNotNeeded(PackOp op) {
5536  auto srcType = op.getSourceType();
5537  if (llvm::any_of(op.getInnerDimsPos(),
5538  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5539  return false;
5540  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5541  return false;
5542  return !PackOp::requirePaddingValue(
5543  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5544  op.getOuterDimsPerm(), op.getMixedTiles());
5545 }
5546 
5547 /// Returns true if the `srcShape` or `destShape` is different from the one in
5548 /// `packOp` and populates each with the inferred static shape.
5549 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5550  SmallVectorImpl<int64_t> &destShape) {
5551  bool changeNeeded = false;
5552  srcShape.assign(packOp.getSourceType().getShape().begin(),
5553  packOp.getSourceType().getShape().end());
5554  destShape.assign(packOp.getDestType().getShape().begin(),
5555  packOp.getDestType().getShape().end());
5556  llvm::SmallSetVector<int64_t, 4> innerDims;
5557  innerDims.insert_range(packOp.getInnerDimsPos());
5558  SmallVector<int64_t> inverseOuterDimsPerm;
5559  if (!packOp.getOuterDimsPerm().empty())
5560  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5561  int srcRank = packOp.getSourceRank();
5562  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5563  if (innerDims.contains(i))
5564  continue;
5565  int64_t srcPos = i;
5566  int64_t destPos = i;
5567  if (!inverseOuterDimsPerm.empty())
5568  destPos = inverseOuterDimsPerm[srcPos];
5569  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5570  ShapedType::isDynamic(destShape[destPos])) {
5571  continue;
5572  }
5573  int64_t size = srcShape[srcPos];
5574  if (ShapedType::isDynamic(size))
5575  size = destShape[destPos];
5576  srcShape[srcPos] = size;
5577  destShape[destPos] = size;
5578  changeNeeded = true;
5579  }
5580  return changeNeeded;
5581 }
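
// Worked example (hypothetical shapes): for a pack with inner_dims_pos = [1],
// source tensor<?x256xf32>, and dest tensor<100x8x32xf32>, dim 0 is untiled,
// so its static size 100 from the dest is propagated back into srcShape.
// The canonicalization below then inserts a tensor.cast on the source.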
5582 
5583 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5584  // Fold pack(unpack(x)) to x.
5585  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5586  if (unPackOp.getSourceType() != packOp.getDestType())
5587  return failure();
5588  if (packOp.getPaddingValue() ||
5589  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5590  !haveSameTiles(packOp, unPackOp))
5591  return failure();
5592  rewriter.replaceOp(packOp, unPackOp.getSource());
5593  return success();
5594  }
5595 
5596  // Fold optional PaddingValue operand away if padding is not needed.
5597  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5598  rewriter.startOpModification(packOp);
5599  packOp.getPaddingValueMutable().clear();
5600  rewriter.finalizeOpModification(packOp);
5601  return success();
5602  }
5603 
5604  // Insert tensor.cast ops if static shape inference is available.
5605  SmallVector<int64_t> srcShape, destShape;
5606  if (inferStaticShape(packOp, srcShape, destShape)) {
5607  Location loc = packOp.getLoc();
5608  Value source = packOp.getSource();
5609  if (srcShape != packOp.getSourceType().getShape()) {
5610  auto newSrcType = packOp.getSourceType().clone(srcShape);
5611  source =
5612  tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5613  }
5614  Value dest = packOp.getDest();
5615  RankedTensorType originalResultType = packOp.getDestType();
5616  bool needUpdateDestType = (destShape != originalResultType.getShape());
5617  if (needUpdateDestType) {
5618  auto newDestType = packOp.getDestType().clone(destShape);
5619  dest =
5620  tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5621  }
5622  rewriter.modifyOpInPlace(packOp, [&] {
5623  packOp.getSourceMutable().assign(source);
5624  packOp.getDestMutable().assign(dest);
5625  packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5626  });
5627  // Insert a cast if needed
5628  if (needUpdateDestType) {
5629  rewriter.setInsertionPointAfter(packOp);
5630  auto castOp =
5631  tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
5632  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
5633  }
5634  return success();
5635  }
5636 
5637  return failure();
5638 }
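
// Illustration of the pack(unpack(x)) fold (hypothetical shapes):
//
// ```mlir
// %u = linalg.unpack %x inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//     into %t : tensor<4x8x8x32xf32> -> tensor<32x256xf32>
// %p = linalg.pack %u inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//     into %d : tensor<32x256xf32> -> tensor<4x8x8x32xf32>
// ```
//
// Both ops use the same attributes and tiles and there is no padding value,
// so all uses of %p are replaced with %x.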
5639 
5640 template <typename PackOrUnpackOp>
5641 static bool isLikePadUnPad(PackOrUnpackOp packOp,
5642  RankedTensorType packedTensorType) {
5643  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5644  std::is_same<PackOrUnpackOp, UnPackOp>::value,
5645  "Function meant for pack/unpack");
5646  // This is a pad if packing only adds ones and we don't transpose dimensions.
5647 
5648  // Check that we are not transposing any dimensions.
5649  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5650  int64_t numPackedDims = innerDimsPos.size();
5651  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5652  if (orderedDims != innerDimsPos) {
5653  // Dimensions don't happen in order.
5654  return false;
5655  }
5656 
5657  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5658  int64_t packedRank = packedTensorType.getRank();
5659  // At this point we know that we are taking numPackedDims outer
5660  // dimensions and pushing them all the way in as the innermost dimensions.
5661  // What's left on the outermost dimensions is, in this order:
5662  // - the factors of the packed dimensions, then
5663  // - the untouched dimensions.
5664  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5665  // if all the dimensions that bubble outward are ones.
5666  // Therefore check that all the dimensions but the numPackedDims innermost
5667  // ones are ones.
5668  return llvm::all_of(
5669  llvm::seq<int64_t>(0, packedRank - numPackedDims),
5670  [&packedShape](int64_t i) { return packedShape[i] == 1; });
5671 }
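
// For example (illustrative): packing tensor<100x256xf32> with
// inner_dims_pos = [0, 1] and inner tiles [100, 256] yields
// tensor<1x1x100x256xf32>; all leading (outer) dims are ones, so the pack is
// like a pad. If any leading dim were greater than one, the inward shift
// would be a real transpose and this returns false.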
5672 
5673 bool PackOp::isLikePad() {
5674  auto packedTensorType =
5675  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5676  return isLikePadUnPad(*this, packedTensorType);
5677 }
5678 
5679 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5680  std::optional<Attribute> paddingValue;
5681  if (auto pad = adaptor.getPaddingValue())
5682  paddingValue = pad;
5683  if (OpFoldResult reshapedSource = reshapeConstantSource(
5684  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5685  getDestType(), paddingValue))
5686  return reshapedSource;
5687  return {};
5688 }
5689 
5690 /// Folds a tensor.cast op into a consuming PackOp if the
5691 /// `tensor.cast` has a source that is more static than the consuming op.
5692 ///
5693 /// Example:
5694 /// ```mlir
5695 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5696 /// %2 = linalg.pack %1 ... : tensor<?x?xf32> ...
5697 /// ```
5698 ///
5699 /// folds into:
5700 ///
5701 /// ```mlir
5702 /// %2 = linalg.pack %0 ... : tensor<8x16xf32> ...
5703 /// ```
5704 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5705  using OpRewritePattern<PackOp>::OpRewritePattern;
5706 
5707  LogicalResult matchAndRewrite(PackOp op,
5708  PatternRewriter &rewriter) const override {
5709  if (!tensor::hasFoldableTensorCastOperand(op))
5710  return failure();
5711 
5711 
5712  SmallVector<Type> newResultTypes(op->getResultTypes());
5713  SmallVector<Value> newOperands =
5714  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5715 
5716  // Get the updated mixed-tile-sizes attribute.
5717  SmallVector<OpFoldResult> newMixedTileSizes =
5718  getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5719 
5720  // Clone op.
5721  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5722  // this point. However, in practice, we use them for things that we'd like
5723  // to preserve. Implement a better abstraction.
5724  PackOp newOp =
5725  PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
5726  op.getInnerDimsPos(), newMixedTileSizes,
5727  op.getPaddingValue(), op.getOuterDimsPerm());
5728  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5729 
5730  // Replace op.
5731  Value oldResult = op.getResult();
5732  Value newResult = newOp.getResult();
5733  Value replacement =
5734  (newResult.getType() != oldResult.getType())
5735  ? tensor::CastOp::create(rewriter, op->getLoc(),
5736  oldResult.getType(), newResult)
5737  : newResult;
5738 
5739  rewriter.replaceOp(op, {replacement});
5740 
5741  return success();
5742  }
5743 };
5744 
5745 //===----------------------------------------------------------------------===//
5746 // UnPackOp
5747 //===----------------------------------------------------------------------===//
5748 
5749 void UnPackOp::getAsmResultNames(
5750  function_ref<void(Value, StringRef)> setNameFn) {
5751  setNameFn(getResult(), "unpack");
5752 }
5753 
5754 LogicalResult
5755 UnPackOp::reifyResultShapes(OpBuilder &builder,
5756  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5757  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5758 }
5759 
5760 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5761  return getDimAndTileMappingImpl(*this);
5762 }
5763 
5764 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5765  return getMixedTilesImpl(*this);
5766 }
5767 
5768 SmallVector<int64_t> UnPackOp::getStaticTiles() {
5769  return getStaticTilesImpl(*this);
5770 }
5771 
5772 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5773  ShapedType destType = getDestType();
5774  int64_t destRank = destType.getRank();
5775  return getSourceType().getShape().take_front(destRank);
5776 }
5777 
5778 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5779  auto innerDimsPos = getInnerDimsPos();
5780  SmallVector<int64_t> outerDims(getAllOuterDims());
5781  SmallVector<int64_t> res;
5782 
5783  // Recover the original order of the outer dims.
5784  SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
5785  invertPermutationVector(outerDimPermInv);
5786  if (!outerDimPermInv.empty())
5787  applyPermutationToVector(outerDims, outerDimPermInv);
5788 
5789  // Collect the outer dims corresponding to the tiled inner dims.
5790  for (auto index : innerDimsPos)
5791  res.push_back(outerDims[index]);
5792 
5793  return res;
5794 }
5795 
5796 LogicalResult UnPackOp::verify() {
5797  return commonVerifierPackAndUnPackOp(*this);
5798 }
5799 
5800 Speculation::Speculatability UnPackOp::getSpeculatability() {
5801  // See PackOp::getSpeculatability.
5802  if (!areTilesAndTiledDimsAllConstant(*this))
5803  return Speculation::NotSpeculatable;
5804 
5805  return Speculation::Speculatable;
5806 }
5807 
5808 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5809  Value dest, ArrayRef<int64_t> innerDimsPos,
5810  ArrayRef<OpFoldResult> innerTiles,
5811  ArrayRef<int64_t> outerDimsPerm) {
5812  assert(innerDimsPos.size() == innerTiles.size() &&
5813  "number of tile sizes specified must match the specified number of "
5814  "original dimensions to be tiled");
5815  SmallVector<int64_t> staticTileSizes;
5816  SmallVector<Value> dynamicTileSizes;
5817  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5818  build(builder, state, dest.getType(), source, dest,
5819  outerDimsPerm.empty() ? nullptr
5820  : builder.getDenseI64ArrayAttr(outerDimsPerm),
5821  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5822  builder.getDenseI64ArrayAttr(staticTileSizes));
5823 }
5824 
5825 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5826  Value source,
5827  ArrayRef<OpFoldResult> innerTileSizes,
5828  ArrayRef<int64_t> innerDimsPos,
5829  ArrayRef<int64_t> outerDimsPerm) {
5830  AffineExpr sym0, sym1;
5831  bindSymbols(b.getContext(), sym0, sym1);
5832  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5833  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5834  };
5835 
5836  SmallVector<OpFoldResult> mixedSizes;
5837  auto srcType = llvm::cast<RankedTensorType>(source.getType());
5838  for (auto i :
5839  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5840  if (srcType.isDynamicDim(i))
5841  mixedSizes.push_back(
5842  tensor::DimOp::create(b, loc, source, i).getResult());
5843  else
5844  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5845  }
5846  if (!outerDimsPerm.empty()) {
5847  applyPermutationToVector<OpFoldResult>(
5848  mixedSizes, invertPermutationVector(outerDimsPerm));
5849  }
5850 
5851  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5852  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5853 
5854  auto elemType = srcType.getElementType();
5855  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5856 }
5857 
5858 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5859  Value transposedSource,
5860  ArrayRef<int64_t> innerPermutation,
5861  ArrayRef<int64_t> outerPermutation) {
5862  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5863  *this, innerPermutation, outerPermutation);
5864  return UnPackOp::create(b, loc, transposedSource, getDest(),
5865  metadata.innerDimsPos, metadata.innerTiles,
5866  metadata.outerDimsPerm);
5867 }
5868 
5869 /// Returns true if the `srcShape` or `destShape` is different from the one in
5870 /// `op` and populates each with the inferred static shape.
5871 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5872  SmallVectorImpl<int64_t> &destShape) {
5873  bool changeNeeded = false;
5874  srcShape.assign(op.getSourceType().getShape().begin(),
5875  op.getSourceType().getShape().end());
5876  destShape.assign(op.getDestType().getShape().begin(),
5877  op.getDestType().getShape().end());
5878  llvm::SmallSetVector<int64_t, 4> innerDims;
5879  innerDims.insert_range(op.getInnerDimsPos());
5880  SmallVector<int64_t> inverseOuterDimsPerm;
5881  if (!op.getOuterDimsPerm().empty())
5882  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5883  int destRank = op.getDestRank();
5884  for (auto i : llvm::seq<int64_t>(0, destRank)) {
5885  if (innerDims.contains(i))
5886  continue;
5887  int64_t srcPos = i;
5888  int64_t destPos = i;
5889  if (!inverseOuterDimsPerm.empty())
5890  srcPos = inverseOuterDimsPerm[destPos];
5891  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5892  ShapedType::isDynamic(destShape[destPos])) {
5893  continue;
5894  }
5895  int64_t size = srcShape[srcPos];
5896  if (ShapedType::isDynamic(size))
5897  size = destShape[destPos];
5898  srcShape[srcPos] = size;
5899  destShape[destPos] = size;
5900  changeNeeded = true;
5901  }
5902  return changeNeeded;
5903 }
5904 
5905 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5906  PatternRewriter &rewriter) {
5907  /// unpack(pack(x)) -> x
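  /// For illustration (hypothetical IR, matching tiles, no padding value):
  ///
  /// ```mlir
  /// %p = linalg.pack %x inner_dims_pos = [0] inner_tiles = [8]
  ///        into %pd : tensor<16xf32> -> tensor<2x8xf32>
  /// %u = linalg.unpack %p inner_dims_pos = [0] inner_tiles = [8]
  ///        into %ud : tensor<2x8xf32> -> tensor<16xf32>
  /// // %u folds to %x.
  /// ```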
5908  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5909  if (packOp.getSourceType() != unPackOp.getDestType())
5910  return failure();
5911  if (packOp.getPaddingValue() ||
5912  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5913  !haveSameTiles(packOp, unPackOp))
5914  return failure();
5915  rewriter.replaceOp(unPackOp, packOp.getSource());
5916  return success();
5917  }
5918  /// unpack(destinationStyleOp(x)) -> unpack(x)
5919  if (auto dstStyleOp =
5920  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5921  auto destValue = cast<OpResult>(unPackOp.getDest());
5922  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5923  rewriter.modifyOpInPlace(unPackOp,
5924  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5925  return success();
5926  }
5927  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
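  /// For illustration (hypothetical shapes): if the unpack yields
  /// tensor<16xf32> and its only user is a zero-offset, unit-stride
  /// extract_slice producing tensor<15xf32>, the slice is absorbed by
  /// shrinking the unpack's destination to tensor<15xf32> (legality is
  /// checked by canFoldSliceOp below).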
5928  if (unPackOp->hasOneUse()) {
5929  auto extractSliceUser =
5930  dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5931  if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
5932  OpBuilder::InsertionGuard g(rewriter);
5933  rewriter.setInsertionPoint(unPackOp);
5934  auto newDest = tensor::ExtractSliceOp::create(
5935  rewriter, unPackOp->getLoc(), unPackOp.getDest(),
5936  extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5937  extractSliceUser.getMixedStrides());
5938  rewriter.modifyOpInPlace(unPackOp, [&]() {
5939  unPackOp.setDpsInitOperand(0, newDest);
5940  unPackOp.getResult().setType(newDest.getType());
5941  });
5942  rewriter.replaceOp(extractSliceUser, unPackOp);
5943  return success();
5944  }
5945  }
5946 
5947  // Insert tensor.cast ops if static shape inference is available.
5948  SmallVector<int64_t> srcShape, destShape;
5949  if (inferStaticShape(unPackOp, srcShape, destShape)) {
5950  Location loc = unPackOp.getLoc();
5951  Value source = unPackOp.getSource();
5952  if (srcShape != unPackOp.getSourceType().getShape()) {
5953  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5954  source = tensor::CastOp::create(rewriter, loc, newSrcType,
5955  unPackOp.getSource());
5956  }
5957  Value dest = unPackOp.getDest();
5958  if (destShape != unPackOp.getDestType().getShape()) {
5959  auto newDestType = unPackOp.getDestType().clone(destShape);
5960  dest = tensor::CastOp::create(rewriter, loc, newDestType,
5961  unPackOp.getDest());
5962  }
5963  Value newOp = UnPackOp::create(
5964  rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
5965  unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
5966  rewriter.replaceOpWithNewOp<tensor::CastOp>(
5967  unPackOp, unPackOp.getResult().getType(), newOp);
5968  return success();
5969  }
5970 
5971  return failure();
5972 }
5973 
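/// For illustration (hypothetical sizes) of the padding check below: with an
/// outer dim of 2 and a tile size of 8, a sliced result dim of 15 implies
/// padding 2 * 8 - 15 = 1 < 8, so the slice only trims padding and folding is
/// allowed; a result dim of 7 would imply padding 9 >= 8 (a whole tile
/// dropped) and is rejected.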
5974 bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
5975  // Rank-reduced folding is not supported.
5976  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
5977  return false;
5978  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
5979  !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
5980  return false;
5981  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
5982  SmallVector<int64_t> outerShapeWithoutTranspose =
5983       getPackedOuterShapeWithoutTransposition<UnPackOp>(*this);
5984  for (auto [pos, tileSize] :
5985  llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
5986  if (unpackedTypeAfterFold.isDynamicDim(pos))
5987  return false;
5988  if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
5989  return false;
5990  if (ShapedType::isDynamic(tileSize))
5991  return false;
5992  int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
5993  unpackedTypeAfterFold.getDimSize(pos);
5994  if (paddingSize >= tileSize)
5995  return false;
5996  }
5997  return true;
5998 }
5999 
6000 bool UnPackOp::isLikeUnPad() {
6001  RankedTensorType packedTensorType = getSourceType();
6002  return isLikePadUnPad(*this, packedTensorType);
6003 }
6004 
6005 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
6006  if (OpFoldResult reshapedSource = reshapeConstantSource(
6007  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
6008  getResult().getType()))
6009  return reshapedSource;
6010  return {};
6011 }
6012 
6013 /// Folds a tensor.cast op into a consuming UnPackOp if the
6014 /// `tensor.cast` has a source that is more static than the consuming op.
6015 ///
6016 /// Example:
6017 /// ```mlir
6018 /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
6019 /// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
6020 /// ```
6021 ///
6022 /// folds into:
6023 ///
6024 /// ```mlir
6025 /// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
6026 /// ```
6027 struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
6028  using OpRewritePattern<UnPackOp>::OpRewritePattern;
6029 
6030  LogicalResult matchAndRewrite(UnPackOp op,
6031  PatternRewriter &rewriter) const override {
6032  if (!tensor::hasFoldableTensorCastOperand(op))
6033  return failure();
6034 
6035  SmallVector<Type> newResultTypes(op->getResultTypes());
6036  SmallVector<Value> newOperands =
6037      tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
6038  Value sourceTensor = newOperands[0];
6039 
6040  // Get the updated mixed-tile-sizes attribute.
6041  SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
6042  rewriter, sourceTensor.getType(), op.getMixedTiles());
6043 
6044  // Clone op.
6045  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
6046  // this point. However, in practice, we use them for things that we'd like
6047  // to preserve. Implement a better abstraction.
6048  UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
6049  newOperands[1], op.getInnerDimsPos(),
6050  newMixedTileSizes, op.getOuterDimsPerm());
6051  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
6052 
6053  // Replace op.
6054  Value oldResult = op.getResult();
6055  Value newResult = newOp.getResult();
6056  Value replacement =
6057  (newResult.getType() != oldResult.getType())
6058  ? tensor::CastOp::create(rewriter, op->getLoc(),
6059  oldResult.getType(), newResult)
6060  : newResult;
6061 
6062  rewriter.replaceOp(op, {replacement});
6063 
6064  return success();
6065  }
6066 };
6067 
6068 //===----------------------------------------------------------------------===//
6069 // BatchReduceMatmulOp
6070 //===----------------------------------------------------------------------===//
6071 SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
6072  return SmallVector<utils::IteratorType>{
6073  utils::IteratorType::reduction, utils::IteratorType::parallel,
6074  utils::IteratorType::parallel, utils::IteratorType::reduction};
6075 }
6076 
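/// The maps built below are, in affine notation (reading d0 = batch, d1 = m,
/// d2 = n, d3 = k, where d0 and d3 are the reduction dims per
/// getIteratorTypesArray above):
///   A: (d0, d1, d2, d3) -> (d0, d1, d3)
///   B: (d0, d1, d2, d3) -> (d0, d3, d2)
///   C: (d0, d1, d2, d3) -> (d1, d2)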
6077 SmallVector<AffineMap>
6078 BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
6079  AffineExpr d0, d1, d2, d3;
6080  SmallVector<AffineMap> indexingMaps;
6081  bindDims(context, d0, d1, d2, d3);
6082  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
6083  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
6084  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
6085  return indexingMaps;
6086 }
6087 
6088 bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
6089  ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
6090  if (!maps)
6091  return false;
6092  if (maps.size() != 3)
6093  return false;
6094  auto positions = getAffineResultPositions(maps);
6095  if (failed(positions))
6096  return false;
6097  return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
6098  (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
6099  (*positions)[2] == SmallVector<int64_t>{1, 2};
6100 }
6101 unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
6102 
6103 std::string BatchReduceMatmulOp::getLibraryCallName() {
6104  return generateLibraryCallName(getOperation());
6105 }
6106 
6107 /// Check if the op has broadcast and/or transpose semantics. Returns true if
6108 /// the user-defined indexing maps are not equal to the default maps.
6109 bool BatchReduceMatmulOp::hasUserDefinedMaps() {
6110  SmallVector<AffineMap, 3> defaultMaps =
6111  getDefaultIndexingMaps(this->getContext());
6112  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
6113  return defaultMaps != explicitMaps;
6114 }
6115 
6116 /// Returns true if the given `bcastMap` is a valid broadcast map. A valid
6117 /// broadcast map must include the K dimension.
6118 /// TODO: Strict inclusion of the K dimension in the broadcast map is not
6119 /// necessary for both input matrices simultaneously. We can relax this
6120 /// condition to require the K dimension in only one input matrix map and
6121 /// infer it for the other input matrix map from the one that already has
6122 /// the K dimension.
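/// For illustration (hypothetical maps, dims d0 = batch, d1 = m, d2 = n,
/// d3 = k): the 1-result map (d0, d1, d2, d3) -> (d3) is accepted for either
/// side, and a 2-result LHS map (d0, d1, d2, d3) -> (d1, d3) is accepted
/// because it retains m and k.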
6123 bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
6124  bool isLHS) {
6125  assert(bcastMap.getNumResults() < 3 &&
6126  "Expected less than 3 result dim expr.");
6127  bool isValid = false;
6128  enum Indices { batchPos, mPos, nPos, kPos };
6129  if (bcastMap.getNumResults() == 1) {
6130  AffineExpr expr = bcastMap.getResult(0);
6131  isValid = expr.isFunctionOfDim(kPos);
6132  } else if (bcastMap.getNumResults() == 2) {
6133  AffineExpr expr0 = bcastMap.getResult(0);
6134  AffineExpr expr1 = bcastMap.getResult(1);
6135  isValid =
6136  isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
6137  expr0.isFunctionOfDim(mPos)) &&
6138  expr1.isFunctionOfDim(kPos))
6139  : ((expr0.isFunctionOfDim(batchPos) &&
6140  expr1.isFunctionOfDim(kPos)) ||
6141  (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
6142  }
6143  return isValid;
6144 }
6145 
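/// For f32 operands, the region built below is equivalent to (illustrative
/// IR):
///
/// ```mlir
/// ^bb0(%a: f32, %b: f32, %acc: f32):
///   %0 = arith.mulf %a, %b : f32
///   %1 = arith.addf %acc, %0 : f32
///   linalg.yield %1 : f32
/// ```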
6146 void BatchReduceMatmulOp::regionBuilder(
6147     ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
6148     function_ref<InFlightDiagnostic()> emitError) {
6149  if (emitError && block.getNumArguments() != 3) {
6150  emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
6151  << block.getNumArguments();
6152  return;
6153  }
6154  assert(block.getNumArguments() == 3 &&
6155  "BatchReduceMatmulOp regionBuilder expects 3 args");
6156  RegionBuilderHelper helper(b, block);
6157  SmallVector<Value> yields;
6158 
6159  auto toType = block.getArgument(2).getType();
6160  Value castValA =
6161  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
6162  Value castValB =
6163  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
6164  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
6165  Value addVal =
6166  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
6167  yields.push_back(addVal);
6168  helper.yieldOutputs(yields);
6169 }
6170 
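/// The custom assembly handled below looks like (illustrative IR; the
/// indexing_maps clause is optional and defaults to the maps above):
///
/// ```mlir
/// %0 = linalg.batch_reduce_matmul
///        ins(%A, %B : tensor<2x4x8xf32>, tensor<2x8x16xf32>)
///        outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
/// ```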
6171 ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
6172  OperationState &result) {
6173  SmallVector<Attribute, 3> indexingMapsAttr;
6174  Attribute mapAttr;
6175  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
6176  if (parser.parseEqual())
6177  return failure();
6178  if (parser.parseLSquare())
6179  return failure();
6180 
6181  do {
6182  if (parser.parseAttribute(mapAttr))
6183  return failure();
6184  if (!isa<AffineMapAttr>(mapAttr)) {
6185  return parser.emitError(parser.getCurrentLocation(),
6186  "expected affine map attribute");
6187  }
6188  indexingMapsAttr.push_back(mapAttr);
6189 
6190  if (parser.parseOptionalComma())
6191  break;
6192  } while (true);
6193 
6194  if (parser.parseRSquare())
6195  return failure();
6196  }
6197  // Initialize indexingMaps, if not supplied explicitly.
6198  if (indexingMapsAttr.empty()) {
6199  indexingMapsAttr = llvm::map_to_vector(
6200  BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
6201  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6202  }
6203  result.addAttribute("indexing_maps",
6204  parser.getBuilder().getArrayAttr(indexingMapsAttr));
6205  return ::parseNamedStructuredOp(parser, result,
6206  BatchReduceMatmulOp::getNumRegionArgs(),
6207  BatchReduceMatmulOp::getRegionBuilder());
6208 }
6209 
6210 void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
6211  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
6212  BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
6213  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
6214 
6215  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
6216  p << " indexing_maps = [";
6217  llvm::interleaveComma(getIndexingMaps(), p,
6218  [&](Attribute attr) { p.printAttribute(attr); });
6219  p << "]";
6220  }
6221 
6222  SmallVector<StringRef, 3> elidedAttrs = {
6223  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
6224  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
6225  elidedAttrs);
6226 }
6227 
6228 /// Verify the user defined indexing maps.
6229 LogicalResult BatchReduceMatmulOp::verify() {
6230  // Verification of pure batch_reduce_matmul is handled by
6231  // verifyStructuredOpInterface().
6232  if (!hasUserDefinedMaps())
6233  return success();
6234 
6235  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
6236  if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
6237  return failure();
6238  }
6239  return success();
6240 }
6241 LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
6242                                         SmallVectorImpl<OpFoldResult> &) {
6243  return memref::foldMemRefCast(*this);
6244 }
6245 void BatchReduceMatmulOp::getEffects(
6246     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
6247  &effects) {
6248  if (hasPureTensorSemantics())
6249  return;
6250  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
6251 }
6252 
6253 Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
6254  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
6255 }
6256 
6257 } // namespace linalg
6258 } // namespace mlir
6259 
6260 //===----------------------------------------------------------------------===//
6261 // LinalgDialect
6262 //===----------------------------------------------------------------------===//
6263 
6264 void LinalgDialect::getCanonicalizationPatterns(
6265  RewritePatternSet &results) const {
6266  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
6267  FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
6268 }
6269 
6270 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
6271  Attribute value, Type type,
6272  Location loc) {
6273  return arith::ConstantOp::materialize(builder, value, type, loc);
6274 }