//===- LinalgOps.cpp - Implementation of the linalg operations -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return tensor::DimOp::create(builder, loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return memref::DimOp::create(builder, loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of
/// the `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes,
                                              strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return memref::SubViewOp::create(b, loc, source, offsets, sizes,
                                         strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
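
// Illustrative example (not from the original source): on a tensor-typed
// source this helper materializes a slice such as
//   %s = tensor.extract_slice %t[%off] [%sz] [1] : tensor<8xf32> to tensor<?xf32>
// while on a memref-typed source it materializes the analogous
// memref.subview; any other type yields a nullptr.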

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(
    ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
    function_ref<InFlightDiagnostic()>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   function_ref<InFlightDiagnostic()> emitError,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs, emitError);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), /*emitError=*/{},
                         regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
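
// Illustrative note (not from the original source): for a plain matmul the
// default indexing maps produced above are, in affine-map form,
//   (d0, d1, d2) -> (d0, d2)   // A
//   (d0, d1, d2) -> (d2, d1)   // B
//   (d0, d1, d2) -> (d0, d1)   // C
// with iterator types [parallel, parallel, reduction].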

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it
    // to the discardable attributes dictionary where it is handled by the
    // generic Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
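
// Illustrative example (not from the original source): for a structured op
// with two tensor inputs and one tensor init, the printer above emits
// something like
//   ins(%a, %b : tensor<4xf32>, tensor<4xf32>) outs(%c : tensor<4xf32>)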

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder, SMLoc loc) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  ParseResult result = success();
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes, attrs,
      [&]() {
        result = failure();
        return parser.emitError(loc);
      },
      regionBuilder);
  return result;
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  SMLoc loc = parser.getCurrentLocation();
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder, loc))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code. They are helpers that build the unary, binary, and type conversion
// functions defined by the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc
// for the code that uses this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr.
// This can be extended later if it becomes possible to fail construction of
// the region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of
// integer and floating point types, but they will not internally cast within
// bit widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or
// need to be handled with care, and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    if (!isFloatingPoint(arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return math::ExpOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::log:
      return math::LogOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::abs:
      return math::AbsFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::ceil:
      return math::CeilOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::floor:
      return math::FloorOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::negf:
      return arith::NegFOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = arith::ConstantOp::create(builder, arg.getLoc(),
                                           ::cast<TypedAttr>(oneAttr));
      return arith::DivFOp::create(builder, arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return math::RoundOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return math::SqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return math::RsqrtOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::square:
      return arith::MulFOp::create(builder, arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return math::TanhOp::create(builder, arg.getLoc(), arg);
    case UnaryFn::erf:
      return math::ErfOp::create(builder, arg.getLoc(), arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  // If emitError is provided, an error will be emitted if the operation is
  // not supported and a nullptr will be returned; otherwise an assertion will
  // be raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return complex::MulOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool)
        return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1);
      return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    YieldOp::create(builder, loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return arith::ConstantOp::create(builder, loc,
                                     ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return IndexOp::create(builder, builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    if (isa<UnknownLoc>(loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
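
// Illustrative example (not from the original source): with buffer semantics,
// a self copy such as
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
// is simply erased; with tensor semantics the result is replaced by the
// input value.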

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = TensorReshapeOp::create(
          rewriter, loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
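
// Illustrative example (not from the original source):
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<16xf32>) -> tensor<16xf32>
//   %r = tensor.expand_shape %f [[0, 1]] output_shape [4, 4]
//        : tensor<16xf32> into tensor<4x4xf32>
// is rewritten to reshape the init first and fill the reshaped tensor:
//   %e = tensor.expand_shape %init [[0, 1]] output_shape [4, 4]
//        : tensor<16xf32> into tensor<4x4xf32>
//   %r = linalg.fill ins(%cst : f32) outs(%e : tensor<4x4xf32>) -> tensor<4x4xf32>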

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor =
        tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(),
                                padOp.getResultType().getElementType());
    Value replacement =
        FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue},
                       ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(),
                                           padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
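
// Illustrative example (not from the original source):
//   %f = linalg.fill ins(%cst : f32) outs(%init : tensor<6xf32>) -> tensor<6xf32>
//   %p = tensor.pad %f low[1] high[1] {
//   ^bb0(%i: index):
//     tensor.yield %cst : f32
//   } : tensor<6xf32> to tensor<8xf32>
// becomes a single fill of an empty tensor of the padded shape:
//   %e = tensor.empty() : tensor<8xf32>
//   %p = linalg.fill ins(%cst : f32) outs(%e : tensor<8xf32>) -> tensor<8xf32>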

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this,
      // we cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
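
// Illustrative example (not from the original source): when the destination
// is a linalg.fill with the same constant as the pad value, inserting a
// padded source such as
//   %p = tensor.pad %src low[1] high[1] { ... tensor.yield %cst ... }
//        : tensor<6xf32> to tensor<8xf32>
//   %r = tensor.insert_slice %p into %fill[2] [8] [1] : tensor<8xf32> into ...
// is rewritten to insert the unpadded source directly, shifting the offsets
// by the low padding:
//   %r = tensor.insert_slice %src into %fill[3] [6] [1] : tensor<6xf32> into ...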

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
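
// Illustrative example (not from the original source):
//   %f = linalg.fill ins(%cst : f32) outs(%t : tensor<8xf32>) -> tensor<8xf32>
//   %v = tensor.extract %f[%i] : tensor<8xf32>
// folds so that every use of %v becomes a use of %cst.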

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(),
                                packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};
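
// Illustrative example (not from the original source):
//   %f = linalg.fill ins(%cst : f32) outs(%src : tensor<16x16xf32>) -> tensor<16x16xf32>
//   %p = linalg.pack %f inner_dims_pos = [0, 1] inner_tiles = [8, 8]
//        into %dest : tensor<16x16xf32> -> tensor<2x2x8x8xf32>
// becomes a fill of the pack destination:
//   %p = linalg.fill ins(%cst : f32) outs(%dest : tensor<2x2x8x8xf32>) -> tensor<2x2x8x8xf32>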

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
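
// Illustrative example (not from the original source): copying the result of
// a fill
//   %f = linalg.fill ins(%cst : f32) outs(%a : tensor<8xf32>) -> tensor<8xf32>
//   %c = linalg.copy ins(%f : tensor<8xf32>) outs(%b : tensor<8xf32>) -> tensor<8xf32>
// becomes a fill of the copy's destination:
//   %c = linalg.fill ins(%cst : f32) outs(%b : tensor<8xf32>) -> tensor<8xf32>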

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
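
// Illustrative example (not from the original source): transposing a filled
// tensor
//   %f = linalg.fill ins(%cst : f32) outs(%a : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %t = linalg.transpose ins(%f : tensor<4x8xf32>)
//        outs(%b : tensor<8x4xf32>) permutation = [1, 0]
// becomes a fill of the transpose init:
//   %t = linalg.fill ins(%cst : f32) outs(%b : tensor<8x4xf32>) -> tensor<8x4xf32>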

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(),
                                                concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
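
// Illustrative example (not from the original source): concatenating two
// fills of the same value
//   %f0 = linalg.fill ins(%cst : f32) outs(%a : tensor<4xf32>) -> tensor<4xf32>
//   %f1 = linalg.fill ins(%cst : f32) outs(%b : tensor<4xf32>) -> tensor<4xf32>
//   %c = tensor.concat dim(0) %f0, %f1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
// becomes a concat of the inits followed by a single fill:
//   %o = tensor.concat dim(0) %a, %b : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
//   %c = linalg.fill ins(%cst : f32) outs(%o : tensor<8xf32>) -> tensor<8xf32>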

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
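
// Illustrative example (not from the original source) of the textual form
// handled by the printer and parser above:
//   %r = linalg.generic
//       {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
//        iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     %0 = arith.addf %a, %a : f32
//     linalg.yield %0 : f32
//   } -> tensor<8xf32>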

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with
  // memory semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::RecursivelySpeculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = sparse_tensor::ConvertOp::create(
              rewriter, linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(),
                                               resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
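
// Illustrative example (not from the original source): an identity generic
// with tensor semantics such as
//   %r = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>
// is replaced by %in (with a tensor.cast inserted if the types differ).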

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1498 
1499 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1500  const OperationName &payloadOpName,
1501  const NamedAttrList &payloadOpAttrs,
1502  ArrayRef<Value> operands,
1503  bool initFirst = false) {
1504  OpBuilder b(parser.getContext());
1505  Region *body = result.addRegion();
1506  Block &block = body->emplaceBlock();
1507  b.setInsertionPointToStart(&block);
1508  for (auto &operand : operands) {
1509  block.addArgument(
1510  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1511  b.getUnknownLoc());
1512  }
1513  SmallVector<Value> payloadOpOperands;
1514  // If the initFirst flag is set, the init block argument is placed first
1515  // among the payload operands.
1516  if (initFirst) {
1517  payloadOpOperands.push_back(block.getArguments().back());
1518  for (const auto &arg : block.getArguments().drop_back())
1519  payloadOpOperands.push_back(arg);
1520  } else {
1521  payloadOpOperands = {block.getArguments().begin(),
1522  block.getArguments().end()};
1523  }
1524 
1525  Operation *payloadOp = b.create(
1526  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1527  payloadOpOperands,
1528  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1529  .getElementType()},
1530  payloadOpAttrs);
1531  YieldOp::create(b, result.location, payloadOp->getResults());
1532 }
1533 
1534 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1535  std::optional<OperationName> payloadOpName;
1536  NamedAttrList payloadOpAttrs;
1537  if (succeeded(parser.parseOptionalLBrace())) {
1538  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1539  if (failed(operationName))
1540  return failure();
1541  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1542  return failure();
1543  payloadOpName = operationName.value();
1544  if (parser.parseRBrace())
1545  return failure();
1546  }
1547 
1548  if (parseDstStyleOp(parser, result))
1549  return failure();
1550 
1551  if (payloadOpName.has_value()) {
1552  if (!result.operands.empty())
1553  addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1554  payloadOpAttrs,
1555  ArrayRef(result.operands).drop_back());
1556  else
1557  result.addRegion();
1558  } else {
1559  SmallVector<OpAsmParser::Argument> regionArgs;
1560  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1561  /*allowType=*/true, /*allowAttrs=*/true)) {
1562  return failure();
1563  }
1564  Region *body = result.addRegion();
1565  if (parser.parseRegion(*body, regionArgs))
1566  return failure();
1567  }
1568  return success();
1569 }
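// For intuition, the two forms accepted by this parser correspond roughly to
// the following (a hedged sketch with hypothetical SSA names):
//
// Short form with a payload op in braces:
//   %mapped = linalg.map { arith.addf }
//             ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//             outs(%init : tensor<8xf32>)
//
// Long form with an explicit region:
//   %mapped = linalg.map
//             ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//             outs(%init : tensor<8xf32>)
//             (%l: f32, %r: f32) {
//               %0 = arith.addf %l, %r : f32
//               linalg.yield %0 : f32
//             }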
1570 
1571 // Retrieve the operation from the body, if it is the only one (besides the
1572 // yield) and if it takes the same number of arguments as the body has.
1573 // If the initFirst flag is enabled, we check that init occupies the first
1574 // position among the operands of the payload.
1575 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1576  if (body->getOperations().size() != 2)
1577  return nullptr;
1578  Operation &payload = body->getOperations().front();
1579  assert(isa<YieldOp>(body->getOperations().back()));
1580 
1581  if (payload.getNumOperands() == 0 ||
1582  payload.getNumOperands() != body->getNumArguments())
1583  return nullptr;
1584  if (initFirst) {
1585  // check init
1586  if (payload.getOperands().back() != body->getArgument(0))
1587  return nullptr;
1588  // check rest
1589  for (const auto &[operand, bbArg] :
1590  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1591  if (bbArg != operand)
1592  return nullptr;
1593  }
1594  } else {
1595  for (const auto &[operand, bbArg] :
1596  llvm::zip(payload.getOperands(), body->getArguments())) {
1597  if (bbArg != operand)
1598  return nullptr;
1599  }
1600  }
1601  return &payload;
1602 }
1603 
1604 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1605  SmallVector<StringRef> elidedAttrs;
1606  std::string attrToElide;
1607  p << " { " << payloadOp->getName().getStringRef();
1608  for (const auto &attr : payloadOp->getAttrs()) {
1609  auto fastAttr =
1610  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1611  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1612  attrToElide = attr.getName().str();
1613  elidedAttrs.push_back(attrToElide);
1614  break;
1615  }
1616  }
1617  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1618  p << " }";
1619 }
1620 
1621 void MapOp::print(OpAsmPrinter &p) {
1622  Block *mapper = getBody();
1623  Operation *payloadOp = findPayloadOp(mapper);
1624  if (payloadOp) {
1625  printShortForm(p, payloadOp);
1626  }
1627 
1628  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1629  p.printOptionalAttrDict((*this)->getAttrs());
1630 
1631  if (!payloadOp) {
1632  // Print region if the payload op was not detected.
1633  p.increaseIndent();
1634  p.printNewline();
1635  p << "(";
1636  llvm::interleaveComma(mapper->getArguments(), p,
1637  [&](auto arg) { p.printRegionArgument(arg); });
1638  p << ") ";
1639 
1640  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1641  p.decreaseIndent();
1642  }
1643 }
1644 
1645 LogicalResult MapOp::verify() {
1646  auto *bodyBlock = getBody();
1647  auto blockArgs = bodyBlock->getArguments();
1648 
1649  // Checks if the number of `inputs` match the arity of the `mapper` region.
1650  if (getInputs().size() != blockArgs.size())
1651  return emitOpError() << "expects number of operands to match the arity of "
1652  "mapper, but got: "
1653  << getInputs().size() << " and " << blockArgs.size();
1654 
1655  // The parameters of mapper should all match the element type of inputs.
1656  for (const auto &[bbArgType, inputArg] :
1657  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1658  auto inputElemType =
1659  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1660  if (bbArgType != inputElemType) {
1661  return emitOpError() << "expected element type of input " << inputElemType
1662  << " to match bbArg type " << bbArgType;
1663  }
1664  }
1665 
1666  // The shape of each input must match the shape of the output.
1667  auto outputShape = getInit().getType().getShape();
1668  for (Type inputArgType : TypeRange{getInputs()}) {
1669  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1670  if (inputElemShape != outputShape) {
1671  return emitOpError() << "expected shape of input (" << inputElemShape
1672  << ") to match shape of output (" << outputShape
1673  << ")";
1674  }
1675  }
1676 
1677  return success();
1678 }
1679 
1680 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1681  int64_t rank = getInit().getType().getRank();
1682  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1683 }
1684 
1685 ArrayAttr MapOp::getIndexingMaps() {
1686  Builder builder(getContext());
1687  int64_t rank = getInit().getType().getRank();
1688  int64_t numIndexingMaps = getOperands().size();
1689  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1690  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1691 }
1692 
1693 void MapOp::getEffects(
1694  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1695  &effects) {
1696  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1697 }
1698 
1699 Speculation::Speculatability MapOp::getSpeculatability() {
1700  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1701 }
1702 
1703 //===----------------------------------------------------------------------===//
1704 // ReduceOp
1705 //===----------------------------------------------------------------------===//
1706 
1707 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1708  OpAsmSetValueNameFn setNameFn) {
1709  for (Value v : getRegionInputArgs())
1710  setNameFn(v, "in");
1711  for (Value v : getRegionOutputArgs())
1712  setNameFn(v, "init");
1713 }
1714 
1715 void ReduceOp::getAsmResultNames(
1716  function_ref<void(Value, StringRef)> setNameFn) {
1717  if (!getResults().empty())
1718  setNameFn(getResults().front(), "reduced");
1719 }
1720 
1721 void ReduceOp::build(
1722  OpBuilder &builder, OperationState &result, ValueRange inputs,
1723  ValueRange inits, ArrayRef<int64_t> dimensions,
1724  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1725  ArrayRef<NamedAttribute> attributes) {
1726  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1727  result.addAttributes(attributes);
1728 
1729  // Add output types for `RankedTensorType` output arguments.
1730  for (Value init : inits) {
1731  Type initType = init.getType();
1732  if (llvm::isa<RankedTensorType>(initType))
1733  result.addTypes(initType);
1734  }
1735 
1736  if (bodyBuild)
1737  buildGenericRegion(builder, result.location, *result.regions.front(),
1738  inputs, inits, bodyBuild);
1739 }
1740 
1741 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1742  int64_t inputRank =
1743  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1744  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1745  utils::IteratorType::parallel);
1746  for (int64_t reductionDim : getDimensions())
1747  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1748  return iteratorTypes;
1749 }
1750 
1751 ArrayAttr ReduceOp::getIndexingMaps() {
1752  int64_t inputRank =
1753  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1754  SmallVector<AffineMap> affineMaps(
1755  getNumDpsInputs(),
1756  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1757  AffineMap resultMap =
1758  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1759  .dropResults(getDimensions());
1760  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1761  affineMaps.push_back(resultMap);
1762  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1763 }
1764 
1765 void ReduceOp::getEffects(
1766  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1767  &effects) {
1768  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1769 }
1770 
1771 Speculation::Speculatability ReduceOp::getSpeculatability() {
1772  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1773 }
1774 
1775 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1776  NamedAttrList &attributes,
1777  StringRef attributeName) {
1778  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1779  return failure();
1780 
1781  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1782  return success();
1783 }
1784 
1785 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1786  std::optional<OperationName> payloadOpName;
1787  NamedAttrList payloadOpAttrs;
1788  if (succeeded(parser.parseOptionalLBrace())) {
1789  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1790  if (failed(operationName))
1791  return failure();
1792  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1793  return failure();
1794  payloadOpName = operationName.value();
1795  if (parser.parseRBrace())
1796  return failure();
1797  }
1798 
1799  if (parseDstStyleOp(
1800  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1801  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1802  }))
1803  return failure();
1804 
1805  if (payloadOpName.has_value()) {
1806  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1807  ArrayRef(result.operands), /*initFirst=*/true);
1808  } else {
1809  SmallVector<OpAsmParser::Argument> regionArgs;
1810  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1811  /*allowType=*/true, /*allowAttrs=*/true)) {
1812  return failure();
1813  }
1814 
1815  Region *body = result.addRegion();
1816  if (parser.parseRegion(*body, regionArgs))
1817  return failure();
1818  }
1819 
1820  return success();
1821 }
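// Sketch of the short form this parser accepts (hypothetical values):
//
//   %reduced = linalg.reduce { arith.addf }
//              ins(%in : tensor<16x32xf32>)
//              outs(%init : tensor<16xf32>)
//              dimensions = [1]
//
// Because initFirst is true here, the generated payload is
// `arith.addf %init_bbarg, %in_bbarg`: the init block argument is placed
// first among the payload operands.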
1822 
1823 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1824  ArrayRef<int64_t> attributeValue) {
1825  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1826 }
1827 
1828 void ReduceOp::print(OpAsmPrinter &p) {
1829  Block *mapper = getBody();
1830  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1831  if (payloadOp) {
1832  printShortForm(p, payloadOp);
1833  }
1834 
1835  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1836  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1837  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1838  if (!payloadOp) {
1839  // Print region if the payload op was not detected.
1840  p.increaseIndent();
1841  p.printNewline();
1842  p << "(";
1843  llvm::interleaveComma(mapper->getArguments(), p,
1844  [&](auto arg) { p.printRegionArgument(arg); });
1845  p << ") ";
1846 
1847  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1848  p.decreaseIndent();
1849  }
1850 }
1851 
1852 LogicalResult ReduceOp::verify() {
1853  ArrayRef<int64_t> dimensionsRef = getDimensions();
1854 
1855  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1856  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1857  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1858  return emitOpError() << "expects all inputs to have the same shapes. "
1859  "Shape at input-index "
1860  << i
1861  << " is not equal to the shape at input-index 0.";
1862  }
1863  }
1864  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1865  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1866  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1867  return emitOpError() << "expects all outputs to have the same shapes. "
1868  "Shape at output-index "
1869  << i
1870  << " is not equal to the shape at output-index 0.";
1871  }
1872  }
1873  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1874  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1875 
1876  DenseSet<int64_t> dimensionsToReduce;
1877  for (int64_t dimension : dimensionsRef) {
1878  if (dimension < 0 || dimension >= inputType.getRank()) {
1879  return emitOpError()
1880  << "dimensions for reduction should be in the range [0, "
1881  << inputType.getRank() - 1 << "].";
1882  }
1883  dimensionsToReduce.insert(dimension);
1884  }
1885 
1886  auto inputDims = inputType.getShape();
1887  auto initDims = initType.getShape();
1888 
1889  // Input dimensions that will be left after the reduction.
1890  SmallVector<int64_t> reducedInputDims;
1891  for (const auto &en : llvm::enumerate(inputDims)) {
1892  if (!dimensionsToReduce.count(en.index()))
1893  reducedInputDims.push_back(en.value());
1894  }
1895 
1896  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1897  return emitOpError() << "number of dimensions after reduction "
1898  << reducedInputDims.size()
1899  << " doesn't match the init rank "
1900  << initType.getRank();
1901  }
1902 
1903  if (reducedInputDims != initDims)
1904  return emitOpError() << "init dimensions [" << initDims
1905  << "] doesn't match input dimensions after reduction ["
1906  << reducedInputDims << "]";
1907 
1908  Block *block = getBody();
1909  if (block->getNumArguments() != this->getNumOperands())
1910  return emitOpError()
1911  << "mismatching number of operands and block arguments";
1912 
1913  // Check that the first block arguments match the element type of the inputs.
1914  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1915  Type inputElementType =
1916  llvm::cast<ShapedType>(input.getType()).getElementType();
1917  if (inputElementType != bbArg.getType())
1918  return emitOpError()
1919  << "input element type " << inputElementType
1920  << " does not match corresponding block argument type "
1921  << bbArg.getType();
1922  }
1923 
1924  // Check that the last block arguments match the element type of the outputs.
1925  for (auto [output, bbArg] : llvm::zip(
1926  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1927  auto outputElementType =
1928  llvm::cast<ShapedType>(output.getType()).getElementType();
1929  if (outputElementType != bbArg.getType())
1930  return emitOpError()
1931  << "output element type " << outputElementType
1932  << " does not match corresponding block argument type "
1933  << bbArg.getType();
1934  }
1935  return success();
1936 }
1937 
1938 //===----------------------------------------------------------------------===//
1939 // TransposeOp
1940 //===----------------------------------------------------------------------===//
1941 
1942 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1943  Region &region, ValueRange inputs,
1944  ValueRange outputs) {
1945  buildGenericRegion(builder, loc, region, inputs, outputs,
1946  [](OpBuilder &b, Location loc, ValueRange args) {
1947  if (!args.empty())
1948  linalg::YieldOp::create(b, loc, args[0]);
1949  });
1950 }
1951 
1952 void TransposeOp::build(::mlir::OpBuilder &builder,
1953  ::mlir::OperationState &result, Value input, Value init,
1954  DenseI64ArrayAttr permutation,
1955  ArrayRef<NamedAttribute> attributes) {
1956  result.addOperands(input);
1957  result.addOperands(init);
1958  result.addAttribute(getPermutationAttrName(result.name), permutation);
1959  result.addAttributes(attributes);
1960 
1961  // Add output types for `RankedTensorType` output arguments.
1962  Type initType = init.getType();
1963  if (llvm::isa<RankedTensorType>(initType))
1964  result.addTypes(initType);
1965 
1966  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1967  init);
1968 }
1969 
1970 void TransposeOp::build(::mlir::OpBuilder &builder,
1971  ::mlir::OperationState &result, Value input, Value init,
1972  ArrayRef<int64_t> permutation,
1973  ArrayRef<NamedAttribute> attributes) {
1974  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1975  attributes);
1976 }
1977 
1978 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1979  if (failed(parseDstStyleOp(
1980  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1981  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1982  })))
1983  return failure();
1984 
1985  OpBuilder builder(parser.getContext());
1986  buildIdentityRegion(builder, result.location, *result.addRegion(),
1987  /*inputs=*/result.operands,
1988  /*outputs=*/{});
1989  return success();
1990 }
1991 
1992 void TransposeOp::getAsmResultNames(
1993  function_ref<void(Value, StringRef)> setNameFn) {
1994  if (!getResults().empty())
1995  setNameFn(getResults().front(), "transposed");
1996 }
1997 
1998 void TransposeOp::print(OpAsmPrinter &p) {
1999  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2000  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
2001  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
2002 }
2003 
2004 LogicalResult TransposeOp::verify() {
2005  ArrayRef<int64_t> permutationRef = getPermutation();
2006 
2007  if (!isPermutationVector(permutationRef))
2008  return emitOpError("permutation is not valid");
2009 
2010  auto inputType = getInput().getType();
2011  auto initType = getInit().getType();
2012 
2013  int64_t rank = inputType.getRank();
2014 
2015  if (rank != initType.getRank())
2016  return emitOpError() << "input rank " << rank
2017  << " does not match init rank " << initType.getRank();
2018 
2019  if (rank != static_cast<int64_t>(permutationRef.size()))
2020  return emitOpError() << "size of permutation " << permutationRef.size()
2021  << " does not match the argument rank " << rank;
2022 
2023  auto inputDims = inputType.getShape();
2024  auto initDims = initType.getShape();
2025 
2026  for (int64_t i = 0; i < rank; ++i) {
2027  int64_t inputDim = inputDims[permutationRef[i]];
2028  int64_t initDim = initDims[i];
2029 
2030  if (inputDim != initDim) {
2031  return emitOpError() << "dim(result, " << i << ") = " << initDim
2032  << " doesn't match dim(input, permutation[" << i
2033  << "]) = " << inputDim;
2034  }
2035  }
2036 
2037  return success();
2038 }
2039 
2040 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2041  int64_t rank = getInit().getType().getRank();
2042  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2043 }
2044 
2045 ArrayAttr TransposeOp::getIndexingMaps() {
2046  Builder builder(getContext());
2047  int64_t rank = getInit().getType().getRank();
2048  return builder.getAffineMapArrayAttr(
2049  {inversePermutation(AffineMap::getPermutationMap(
2050  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
2051  builder.getMultiDimIdentityMap(rank)});
2052 }
2053 
2054 void TransposeOp::getEffects(
2055  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2056  &effects) {
2057  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2058 }
2059 
2060 Speculation::Speculatability TransposeOp::getSpeculatability() {
2061  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2062 }
2063 
2064 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2065  SmallVectorImpl<OpFoldResult> &result) {
2066  // Only the tensor type is supported.
2067  if (!isa<TensorType>(getInput().getType()))
2068  return failure();
2069 
2070  // Empty permutation (rank-0 input): nothing to transpose.
2071  if (getPermutation().size() == 0) {
2072  result.push_back(getInput());
2073  return success();
2074  }
2075  // Identity permutation.
2076  if (isIdentityPermutation(getPermutation())) {
2077  result.push_back(getInput());
2078  return success();
2079  }
2080 
2081  return failure();
2082 }
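// Illustrative folds (tensor semantics only; names hypothetical):
//
//   %t = linalg.transpose ins(%x : tensor<4x8xf32>)
//        outs(%init : tensor<4x8xf32>) permutation = [0, 1]
//
// folds to %x because [0, 1] is the identity permutation; a rank-0 input
// with an empty permutation folds the same way.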
2083 
2084 /// Fold transpose with transpose.
2085 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2086  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2087 
2088  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2089  PatternRewriter &rewriter) const override {
2090  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2091  if (!defTransposeOp)
2092  return failure();
2093  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2094  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2095  SmallVector<int64_t> foldedPerms;
2096  foldedPerms.reserve(perms.size());
2097  for (int64_t perm : perms)
2098  foldedPerms.push_back(defPerms[perm]);
2099 
2100  rewriter.replaceOpWithNewOp<TransposeOp>(
2101  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2102  foldedPerms);
2103  return success();
2104  }
2105 };
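// Hypothetical example of the composition above: an outer permutation
// perms = [2, 0, 1] applied to an inner defPerms = [1, 2, 0] yields
// foldedPerms[i] = defPerms[perms[i]] = [0, 1, 2], i.e. the two transposes
// compose into an identity transpose that the fold above can then remove.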
2106 
2107 /// This pattern canonicalizes transpose by swapping the order of
2108 /// broadcast and transpose:
2109 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2110 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2111  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2112 
2113  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2114  PatternRewriter &rewriter) const override {
2115  Value input = transposeOp.getInput();
2116  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2117  if (!input.hasOneUse() || !broadcastOp)
2118  return failure();
2119 
2120  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2121  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2122 
2123  // Get new perms and new dimensions.
2124  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2125  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2126  SmallVector<int64_t> resultDimensions;
2127  unsigned dimensionSize = dimensions.size();
2128  for (unsigned i = 0; i < dimensionSize; ++i)
2129  resultDimensions.push_back(invertPerm[dimensions[i]]);
2130 
2131  // Create transpose result.
2132  Value broadcastInput = broadcastOp.getInput();
2133  Location loc = transposeOp.getLoc();
2134  MLIRContext *ctx = transposeOp.getContext();
2135  SmallVector<OpFoldResult> dims;
2136  auto broadcastInputTy =
2137  mlir::cast<RankedTensorType>(broadcastInput.getType());
2138  unsigned inputRank = broadcastInputTy.getRank();
2139  for (unsigned i = 0; i < inputRank; ++i) {
2140  if (broadcastInputTy.isDynamicDim(i)) {
2141  dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i)
2142  ->getResult(0));
2143  } else {
2144  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2145  broadcastInputTy.getDimSize(i)));
2146  }
2147  }
2148  SmallVector<OpFoldResult> transposeResultShapes =
2149  applyPermutation(dims, resultPerms);
2150  Value transposeInit = tensor::EmptyOp::create(
2151  rewriter, transposeOp.getLoc(), transposeResultShapes,
2152  broadcastInputTy.getElementType());
2153 
2154  // Create broadcast(transpose(input)).
2155  Value transposeResult =
2156  TransposeOp::create(rewriter, loc, broadcastOp.getInput(),
2157  transposeInit, resultPerms)
2158  ->getResult(0);
2159  rewriter.replaceOpWithNewOp<BroadcastOp>(
2160  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2161  return success();
2162  }
2163 };
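// Hypothetical illustration of the swap:
//
//   %b = linalg.broadcast ins(%x : tensor<4xf32>)
//        outs(%e0 : tensor<2x4xf32>) dimensions = [0]
//   %t = linalg.transpose ins(%b : tensor<2x4xf32>)
//        outs(%e1 : tensor<4x2xf32>) permutation = [1, 0]
//
// becomes transpose-then-broadcast: the transpose of the 1-D input is
// trivial, and the broadcast dimension moves to invertPerm[0] = 1.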
2164 
2165 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2166  MLIRContext *context) {
2167  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2168 }
2169 
2170 //===----------------------------------------------------------------------===//
2171 // BroadcastOp
2172 //===----------------------------------------------------------------------===//
2173 
2174 void BroadcastOp::build(::mlir::OpBuilder &builder,
2175  ::mlir::OperationState &result, Value input, Value init,
2176  DenseI64ArrayAttr dimensions,
2177  ArrayRef<NamedAttribute> attributes) {
2178  result.addOperands(input);
2179  result.addOperands(init);
2180  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2181  result.addAttributes(attributes);
2182 
2183  // Add output types for `RankedTensorType` output arguments.
2184  Type initType = init.getType();
2185  if (llvm::isa<RankedTensorType>(initType))
2186  result.addTypes(initType);
2187 
2188  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2189  init);
2190 }
2191 
2192 void BroadcastOp::build(::mlir::OpBuilder &builder,
2193  ::mlir::OperationState &result, Value input, Value init,
2194  ArrayRef<int64_t> dimensions,
2195  ArrayRef<NamedAttribute> attributes) {
2196  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2197  attributes);
2198 }
2199 
2200 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2201  if (failed(parseDstStyleOp(
2202  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2203  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2204  })))
2205  return failure();
2206 
2207  OpBuilder builder(parser.getContext());
2208  buildIdentityRegion(builder, result.location, *result.addRegion(),
2209  /*inputs=*/result.operands,
2210  /*outputs=*/{});
2211  return success();
2212 }
2213 
2214 void BroadcastOp::getAsmResultNames(
2215  function_ref<void(Value, StringRef)> setNameFn) {
2216  if (!getResults().empty())
2217  setNameFn(getResults().front(), "broadcasted");
2218 }
2219 
2220 void BroadcastOp::print(OpAsmPrinter &p) {
2221  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2222  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2223  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2224 }
2225 
2226 LogicalResult BroadcastOp::verify() {
2227  ArrayRef<int64_t> dimensionsRef = getDimensions();
2228 
2229  auto inputType = getInput().getType();
2230  auto initType = getInit().getType();
2231 
2232  int64_t inputRank = inputType.getRank();
2233  int64_t initRank = initType.getRank();
2234 
2235  auto inputShape = inputType.getShape();
2236  auto initShape = initType.getShape();
2237 
2238  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2239  return emitOpError() << "input rank plus added dimensions does not "
2240  "match init rank. input rank: "
2241  << inputRank
2242  << ", dimensions size: " << dimensionsRef.size()
2243  << ", init rank: " << initRank;
2244 
2245  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2246  if (dim < 0 || dim >= initRank)
2247  return emitOpError() << "dimension " << idx
2248  << " is out of range. expected range: [0, "
2249  << initRank - 1 << "], got: " << dim;
2250  }
2251 
2252  // Mapping from input dims to init dims.
2253  SmallVector<int64_t> dimMap;
2254  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2255  if (!llvm::is_contained(dimensionsRef, dim))
2256  dimMap.push_back(dim);
2257  }
2258 
2259  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2260  // This dimension is mapped from the input. Init and input dims should
2261  // match.
2262  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2263  return emitOpError() << "input dim " << inputDimIdx
2264  << " should match init dim " << initDimIdx
2265  << ". input: " << inputShape[inputDimIdx]
2266  << ", init: " << initShape[initDimIdx];
2267  }
2268 
2269  return success();
2270 }
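// A hypothetical example that passes the checks above:
//
//   %bc = linalg.broadcast ins(%x : tensor<4x8xf32>)
//         outs(%init : tensor<4x16x8xf32>) dimensions = [1]
//
// input rank (2) plus the number of added dimensions (1) equals init rank
// (3), dimension 1 is in range, and the remaining init dims {0, 2} match
// the input dims {4, 8}.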
2271 
2272 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2273  int64_t rank = getInit().getType().getRank();
2274  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2275 }
2276 
2277 ArrayAttr BroadcastOp::getIndexingMaps() {
2278  Builder builder(getContext());
2279  int64_t rank = getInit().getType().getRank();
2280  return builder.getAffineMapArrayAttr(
2281  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2282  builder.getMultiDimIdentityMap(rank)});
2283 }
2284 
2285 void BroadcastOp::getEffects(
2286  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2287  &effects) {
2288  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2289 }
2290 
2291 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2292  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2293 }
2294 
2295 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2296  MLIRContext *context) {
2297  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2298 }
2299 
2300 //===----------------------------------------------------------------------===//
2301 // YieldOp
2302 //===----------------------------------------------------------------------===//
2303 
2304 void linalg::YieldOp::print(OpAsmPrinter &p) {
2305  if (getNumOperands() > 0)
2306  p << ' ' << getOperands();
2307  p.printOptionalAttrDict((*this)->getAttrs());
2308  if (getNumOperands() > 0)
2309  p << " : " << getOperandTypes();
2310 }
2311 
2312 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2313  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2314  SmallVector<Type, 2> types;
2315  SMLoc loc = parser.getCurrentLocation();
2316  return failure(parser.parseOperandList(opInfo) ||
2317  parser.parseOptionalAttrDict(result.attributes) ||
2318  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2319  parser.resolveOperands(opInfo, types, loc, result.operands));
2320 }
2321 
2322 // Check that the operand number and types match the element types of the
2323 // LinalgOp interface's shaped operands.
2324 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2325  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2326  return op.emitOpError("expected number of yield values (")
2327  << op.getNumOperands()
2328  << ") to match the number of inits / outs operands of the enclosing "
2329  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2330 
2331  for (OpOperand &opOperand : op->getOpOperands()) {
2332  OpOperand *outputOperand =
2333  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2334  Type elementType = outputOperand->get().getType();
2335  if (isa<MemRefType, RankedTensorType>(elementType))
2336  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2337  if (opOperand.get().getType() != elementType)
2338  return op.emitOpError("type of yield operand ")
2339  << (opOperand.getOperandNumber() + 1) << " ("
2340  << opOperand.get().getType() << ") doesn't match "
2341  << "the element type of the enclosing linalg.generic op ("
2342  << elementType << ")";
2343  }
2344  return success();
2345 }
2346 
2347 LogicalResult linalg::YieldOp::verify() {
2348  auto *parentOp = (*this)->getParentOp();
2349  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2350  return emitOpError("expected single non-empty parent region");
2351 
2352  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2353  return verifyYield(*this, linalgOp);
2354 
2355  return emitOpError("expected parent op with LinalgOp interface");
2356 }
2357 
2358 //===----------------------------------------------------------------------===//
2359 // IndexOp
2360 //===----------------------------------------------------------------------===//
2361 
2362 LogicalResult IndexOp::verify() {
2363  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2364  if (!linalgOp)
2365  return emitOpError("expected parent op with LinalgOp interface");
2366  if (linalgOp.getNumLoops() <= getDim())
2367  return emitOpError("expected dim (")
2368  << getDim() << ") to be lower than the number of loops ("
2369  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2370  return success();
2371 }
2372 
2373 OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2374  auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2375  // Bail out if `linalg.index` does not have a proper parent yet at this
2376  // point, e.g., when calling `createOrFold` during IR construction in
2377  // `GenericOp::build`.
2378  if (!linalgOp)
2379  return OpFoldResult{};
2380 
2381  // Index of unit dims is always 0.
2382  SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2383  uint64_t dim = getDim();
2384  assert(dim < loopBounds.size() && "Dim is out of bounds");
2385  if (loopBounds[dim] == 1)
2386  return IntegerAttr::get(IndexType::get(getContext()), 0);
2387 
2388  return OpFoldResult{};
2389 }
2390 
2391 /////// Operations corresponding to library calls defined with Tablegen ////////
2392 
2393 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2394 
2395 #define GET_OP_CLASSES
2396 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2397 
2398 #define GET_OP_CLASSES
2399 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2400 #define GET_OP_CLASSES
2401 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2402 
2403 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2404  unsigned rank,
2405  MLIRContext *context) {
2406  if (maybeMap)
2407  return *maybeMap;
2408  if (rank == 0)
2409  return AffineMap::get(context);
2410  return AffineMap::getMultiDimIdentityMap(rank, context);
2411 }
2412 
2413 SmallVector<AffineExpr, 4>
2414 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2415  MLIRContext *context) {
2416  SmallVector<AffineExpr, 4> res;
2417  res.reserve(num);
2418  for (unsigned i = 0; i < num; ++i)
2419  res.push_back(getAffineDimExpr(startIdx++, context));
2420  return res;
2421 }
2422 
2423 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2424  ArrayRef<AffineExpr> b) {
2425  auto rangeA = llvm::make_range(a.begin(), a.end());
2426  auto rangeB = llvm::make_range(b.begin(), b.end());
2427  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2428  return llvm::to_vector<4>(concatRanges);
2429 }
2430 
2431 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2432  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2433  ss << "view";
2434  for (auto size : memref.getShape())
2435  if (size < 0)
2436  ss << "sx";
2437  else
2438  ss << size << "x";
2439  if (failed(appendMangledType(ss, memref.getElementType())))
2440  return failure();
2441  if (auto as = memref.getMemorySpace()) {
2442  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2443  ss << "as" << attr.getInt();
2444  else
2445  return failure();
2446  }
2447  return success();
2448  }
2449  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2450  ss << "vector";
2451  llvm::interleave(
2452  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2453  if (failed(appendMangledType(ss, vec.getElementType())))
2454  return failure();
2455  return success();
2456  }
2457  if (t.isSignlessIntOrIndexOrFloat()) {
2458  ss << t;
2459  return success();
2460  }
2461  return failure();
2462 }
2463 
2464 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2465  assert(isa<LinalgOp>(op));
2466  std::string name(op->getName().getStringRef().str());
2467  std::string fun = "";
2468  for (NamedAttribute kv : op->getAttrs()) {
2469  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2470  fun = stringifyEnum(ufa.getValue()).str() + "_";
2471  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2472  fun = stringifyEnum(bfa.getValue()).str() + "_";
2473  }
2474  }
2475  name.reserve(128);
2476  llvm::replace(name, '.', '_');
2477  llvm::raw_string_ostream ss(name);
2478  ss << "_" << fun;
2479  for (Type t : op->getOperandTypes()) {
2480  if (failed(appendMangledType(ss, t)))
2481  return std::string();
2482  ss << "_";
2483  }
2484  name.pop_back();
2485  return name;
2486 }
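// As a sketch of this mangling, assuming a linalg.matmul whose three
// operands are all memref<?x?xf32>, the generated name would be
//
//   linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32
//
// where "view" marks a memref operand and each "sx" a dynamic dimension.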
2487 
2488 //===----------------------------------------------------------------------===//
2489 // Canonicalizers and Folders.
2490 //===----------------------------------------------------------------------===//
2491 
2492 namespace {
2493 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2494  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2495 
2496  LogicalResult matchAndRewrite(LinalgOp op,
2497  PatternRewriter &rewriter) const override {
2498  for (OpOperand &opOperand : op->getOpOperands()) {
2499  // Linalg "inputs" may be either tensor or memref type.
2500  // tensor<0xelt_type> is a convention that may not always mean
2501  // "0 iterations". Only erase when we see memref<...x0x...>.
2502  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2503  if (!mt)
2504  continue;
2505  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2506  rewriter.eraseOp(op);
2507  return success();
2508  }
2509  }
2510  return failure();
2511  }
2512 };
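// For example (hypothetically), a linalg op with a memref<4x0x8xf32>
// operand iterates zero times and can be erased, while the analogous
// tensor<4x0x8xf32> case is deliberately left alone, as noted above.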
2513 
2514 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2515 /// result that is more static than the linalg op.
2516 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2517  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2518 
2519  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2520  PatternRewriter &rewriter) const override {
2521  if (!tensor::canFoldIntoProducerOp(castOp))
2522  return failure();
2523 
2524  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2525  if (!linalgOp)
2526  return failure();
2527 
2528  // The cast can be in a conditionally reachable region, in which case
2529  // folding would generate invalid code. Only conservatively fold ops in the
2530  // same block for now.
2531  if (castOp->getBlock() != linalgOp->getBlock())
2532  return failure();
2533 
2534  OpBuilder::InsertionGuard guard(rewriter);
2535  rewriter.setInsertionPoint(linalgOp);
2536 
2537  Location loc = linalgOp.getLoc();
2538  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2539  unsigned resultNumber = resultValue.getResultNumber();
2540  auto resultType =
2541  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2542  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2543  // going from a more dynamic shape to a less dynamic shape. If the producer
2544  // for this cast, i.e. producer of the out operand, is also an operation
2545  // that folds with tensor.cast consumer (like this pattern), the cast will
2546  // continue to propagate as far up the stack as it can go.
2547  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2548  Value newOperand =
2549  tensor::CastOp::create(rewriter, loc, resultType, outOperand->get());
2550  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2551  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2552  linalgOp.getDpsInits().end());
2553  outputOperands[resultNumber] = newOperand;
2554  newOperands.append(outputOperands.begin(), outputOperands.end());
2555 
2556  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2557  linalgOp->result_type_end());
2558  resultTypes[resultNumber] = resultType;
2559  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2560 
2561  // Create a tensor.cast operation back to the original type.
2562  Value castBack = tensor::CastOp::create(
2563  rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber));
2564 
2565  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2566  results[resultNumber] = castBack;
2567  rewriter.replaceOp(linalgOp, results);
2568  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2569  return success();
2570  }
2571 };
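// A hedged sketch of the rewrite (types and names hypothetical):
//
//   %0 = linalg.generic ... outs(%out : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
//
// becomes
//
//   %newOut = tensor.cast %out : tensor<?x?xf32> to tensor<4x8xf32>
//   %2 = linalg.generic ... outs(%newOut : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %castBack = tensor.cast %2 : tensor<4x8xf32> to tensor<?x?xf32>
//
// with uses of %1 replaced by %2 and remaining uses of %0 by %castBack.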
2572 
2573 /// For each operand in `operands`, this function maps the static sizes of
2574 /// dimensions to their affine dim expressions.
2575 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2576  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2577  for (OpOperand &opOperand : operands) {
2578  if (linalgOp.isScalar(&opOperand))
2579  continue;
2580  Value src = opOperand.get();
2581  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2582  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2583 
2584  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2585  // `tensor.cast` operation and source of the cast operation has a static
2586  // shape, then assign it to the `sourceShape`.
2587  auto *parentOp = src.getDefiningOp();
2588  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2589  if (parentOp) {
2590  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2591  Value castSource = castOp.getSource();
2592  auto castSourceType =
2593  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2594  if (castSourceType && castSourceType.hasStaticShape())
2595  sourceShape = castSourceType.getShape();
2596  }
2597  }
2598 
2599  // If the source shape's dimension has a static shape, map the affine dim
2600  // expression to the known static size.
2601  for (unsigned i = 0; i < sourceShape.size(); i++) {
2602  if (sourceType.isDynamicDim(i))
2603  continue;
2604  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2605  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2606  }
2607  }
2608 }
2609 
2610 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
2611 /// mapped in `affineExprToSize`. New operands are created in `newOperands` and
2612 /// their result types are stored in `resultTypes`. If `opOperand` requires no
2613 /// change, then `changeNeeded` is left as-is and the same operand is added to
2614 /// the `newOperands` list.
2615 static void createNewOperandWithStaticSizes(
2616  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2617  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2618  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2619  bool &changeNeeded) {
2620  Value src = opOperand->get();
2621  newOperands.push_back(src);
2622  if (linalgOp.isScalar(opOperand))
2623  return;
2624  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2625  Type resultType = sourceType;
2626  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2627  resultTypes.push_back(resultType);
2628  return;
2629  }
2630  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2631  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2632  SmallVector<int64_t> newShape;
2633  // If operand is updated with new shape, `newOperandNeeded` will be
2634  // true.
2635  bool newOperandNeeded = false;
2636  for (unsigned i = 0; i < sourceShape.size(); i++) {
2637  int64_t dimShape = sourceShape[i];
2638  AffineExpr dimExpr = sourceMap.getResult(i);
2639  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2640  newShape.push_back(dimShape);
2641  continue;
2642  }
2643  // Dimension has a dynamic shape and corresponding affine dim
2644  // expression is present in the map. So assign the size for the
2645  // given affine dim expression to the dimension.
2646  newShape.push_back(affineExprToSize[dimExpr]);
2647  newOperandNeeded = true;
2648  }
2649  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2650  sourceType.getEncoding());
2651  if (newOperandNeeded) {
2652  changeNeeded = true;
2653  // Get the new operand value given its size and element type by
2654  // casting it.
2655  Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src);
2656  unsigned index = opOperand->getOperandNumber();
2657  newOperands[index] = newOperand;
2658  }
2659  if (linalgOp.isDpsInit(opOperand))
2660  resultTypes.push_back(resultType);
2661 }
2662 
2663 /// Static shapes for the operands can be inferred if any one of the operands
2664 /// has a static shape. This can be done by referring to the affine dim
2665 /// expressions for the operand.
2666 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2667  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2668 
2669  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2670  PatternRewriter &rewriter) const override {
2671  if (!linalgOp.hasPureTensorSemantics())
2672  return failure();
2673 
2674  // Maps must be projected permutations.
2675  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2676  return !map.isProjectedPermutation();
2677  }))
2678  return failure();
2679 
2680  // Maps affine dim expressions to the static size of that dimension.
2681  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2682  Location loc = linalgOp.getLoc();
2683 
2684  // For each of the affine dim expression, check if the size is known. If
2685  // known add that in the map.
2686  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2687 
2688  SmallVector<Value> newOperands;
2689  SmallVector<Type> resultTypes;
2690 
2691  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2692  // change in their types.
2693  bool changeNeeded = false;
2694  newOperands.reserve(linalgOp->getNumOperands());
2695  resultTypes.reserve(linalgOp.getNumDpsInits());
2696 
2697  // Iterate over all the operands and update the static sizes.
2698  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2699  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2700  affineExprToSize, linalgOp, newOperands,
2701  resultTypes, changeNeeded);
2702  }
2703 
2704  // If the generic op has all the required static information, no
2705  // canonicalization needed.
2706  if (!changeNeeded)
2707  return failure();
2708 
2709  // Clone op.
2710  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2711  SmallVector<Value> replacements;
2712  replacements.reserve(newOp->getNumResults());
2713  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2714  Value newResult = std::get<1>(it);
2715  Value oldResult = std::get<0>(it);
2716  Type newType = newResult.getType();
2717  Type oldType = oldResult.getType();
2718  replacements.push_back(
2719  (newType != oldType)
2720  ? tensor::CastOp::create(rewriter, loc, oldType, newResult)
2721  : newResult);
2722  }
2723  rewriter.replaceOp(linalgOp, replacements);
2724  return success();
2725  }
2726 };
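// Hypothetical illustration: with identity indexing maps (d0, d1) -> (d0, d1)
// on both operands, ins(%a : tensor<4x?xf32>) and outs(%b : tensor<?x8xf32>)
// jointly pin d0 = 4 and d1 = 8, so both operands are cast to
// tensor<4x8xf32> before cloning, and the new result is cast back to the
// original tensor<?x8xf32> type.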
2727 
2728 } // namespace
2729 
2730 // All named ops canonicalizers and folders are auto-generated in the
2731 // .cpp.inc.
2732 
2733 //===----------------------------------------------------------------------===//
2734 // SoftmaxOp
2735 //===----------------------------------------------------------------------===//
2736 
2737 LogicalResult SoftmaxOp::verify() {
2738  ShapedType inputType = getInputOperandType();
2739  ShapedType outputType = getOutputOperandType();
2740 
2741  ArrayRef<int64_t> inputShape = inputType.getShape();
2742  ArrayRef<int64_t> outputShape = outputType.getShape();
2743  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2744  return emitOpError("incompatible output shape");
2745 
2746  int64_t inputRank = getInputOperandRank();
2747  int64_t dimension = getDimension();
2748  if ((dimension < 0) || (dimension >= inputRank))
2749  return emitOpError("incorrect dimension specified");
2750 
2751  return success();
2752 }
2753 
2754 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2755  int64_t operandRank = getInputOperandRank();
2756  SmallVector<Range> loopBounds(operandRank);
2757  Location loc = getLoc();
2758  Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
2759  Value one = arith::ConstantIndexOp::create(builder, loc, 1);
2760  Value source = getInput();
2761  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2762  loopBounds[dim].offset = zero;
2763  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2764  loopBounds[dim].stride = one;
2765  }
2766  return loopBounds;
2767 }
2768 
2769 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2770  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2771  utils::IteratorType::parallel);
2772  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2773  return iteratorTypes;
2774 }
2775 
2776 FailureOr<TilingResult>
2777 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2778  ArrayRef<OpFoldResult> offsets,
2779  ArrayRef<OpFoldResult> sizes) {
2780  int64_t rank = getInputOperandRank();
2781  auto oneAttr = builder.getI64IntegerAttr(1);
2782  SmallVector<OpFoldResult> strides(rank, oneAttr);
2783  SmallVector<Value> tiledOperands;
2784  Operation *inputSlice =
2785  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2786  if (!inputSlice) {
2787  return emitOpError("failed to compute input slice");
2788  }
2789  tiledOperands.emplace_back(inputSlice->getResult(0));
2790  Operation *outputSlice =
2791  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2792  if (!outputSlice) {
2793  return emitOpError("failed to compute output slice");
2794  }
2795  tiledOperands.emplace_back(outputSlice->getResult(0));
2796 
2797  SmallVector<Type, 4> resultTypes;
2798  if (hasPureTensorSemantics())
2799  resultTypes.push_back(tiledOperands[1].getType());
2800  Operation *tiledOp =
2801  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2802 
2803  return TilingResult{
2804  {tiledOp},
2805  SmallVector<Value>(tiledOp->getResults()),
2806  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2807 }
2808 
2809 LogicalResult SoftmaxOp::getResultTilePosition(
2810  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2811  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2812  SmallVector<OpFoldResult> &resultSizes) {
2813  if (resultNumber == 0) {
2814  resultOffsets.assign(offsets.begin(), offsets.end());
2815  resultSizes.assign(sizes.begin(), sizes.end());
2816  return success();
2817  }
2818  return failure();
2819 }
2820 
2821 // cast(dynamic) -> static.
2822 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2823  return memref::foldMemRefCast(*this);
2824 }
2825 
2826 LogicalResult
2827 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2828  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2829  SmallVector<OpFoldResult> shapes;
2830  Location loc = getOperation()->getLoc();
2831  IRRewriter rewriter(b);
2832  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2833  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2834  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2835  if (!outputShapedType.isDynamicDim(dim)) {
2836  // Static dim: Return IntegerAttr.
2837  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2838  } else {
2839  // Dynamic dim: Return Value.
2840  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2841  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2842  }
2843  }
2844  reifiedReturnShapes.emplace_back(std::move(shapes));
2845  return success();
2846 }
2847 
2848 void SoftmaxOp::getEffects(
2849  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2850  &effects) {
2851  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2852  if (!llvm::isa<MemRefType>(operand.getType()))
2853  continue;
2854  effects.emplace_back(MemoryEffects::Read::get(),
2855  &getOperation()->getOpOperand(index), /*stage=*/0,
2856  /*effectOnFullRegion=*/true,
2857  SideEffects::DefaultResource::get());
2858  }
2859 
2860  for (OpOperand &operand : getDpsInitsMutable()) {
2861  if (!llvm::isa<MemRefType>(operand.get().getType()))
2862  continue;
2863  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2864  /*effectOnFullRegion=*/true,
2865  SideEffects::DefaultResource::get());
2866  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2867  /*effectOnFullRegion=*/true,
2868  SideEffects::DefaultResource::get());
2869  }
2870 }
2871 
2872 // Helper functions for softmax decomposition.
2873 // @{
2874 
2875 // Helper function to produce the iterator types (reduction or parallel) and
2876 // affine maps for the iterators used in the decomposition of softmax.
2877 // This method creates:
2878 // If allParallel == true:
2879 // - iterator type: {parallel, ..., parallel}
2880 // - affine maps:
2881 // -- identity with inputRank dimensions.
2882 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2883 // where N == inputRank - 1.
2884 //
2885 // If allParallel == false:
2886 // - iterator type at dim(i) == parallel for i != \p dim and
2887 // dim(dim) == reduction.
2888 // - affine map:
2889 // -- identity with inputRank dimensions.
2890 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2891 // where N == inputRank - 1.
2892 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2893 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2894  int64_t dim, bool allParallel = false) {
2895  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2896  utils::IteratorType::parallel);
2897  if (!allParallel)
2898  iteratorTypes[dim] = utils::IteratorType::reduction;
2899  MLIRContext *ctxt = builder.getContext();
2900  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2901  SmallVector<AffineExpr, 2> affineExprs;
2902  for (int i = 0; i < inputRank; i++) {
2903  if (i != dim)
2904  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2905  }
2906  auto reductionMap =
2907  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2908  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2909  return std::make_tuple(iteratorTypes, indexingMaps);
2910 }
2911 
2912 // Helper function to produce a linalg.generic that computes a reduction on
2913 // dimension \p dim with the operation type \p T.
2914 template <typename T>
2915 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2916  int64_t dim) {
2917  auto inputType = cast<ShapedType>(input.getType());
2918  ArrayRef<int64_t> inputShape = inputType.getShape();
2919  int64_t inputRank = inputShape.size();
2920  auto [iteratorTypes, indexingMaps] =
2921  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2922  assert(indexingMaps.size() == 2 &&
2923  "We should have two maps: 1 for the input, 1 for the output");
2924  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2925 
2926  auto genericOp = linalg::GenericOp::create(
2927  builder, loc, output.getType(), input, output, indexingMaps,
2928  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2929  Value result = T::create(b, loc, args[0], args[1]);
2930  linalg::YieldOp::create(b, loc, result);
2931  });
2932  return genericOp.getResult(0);
2933 }
2934 
2935 /// Produce a linalg generic that computes the second step of the softmax
2936 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2937 /// on dimension \p dim.
2938 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2939  Value max, Value output, int64_t dim) {
2940  auto inputType = cast<ShapedType>(input.getType());
2941  ArrayRef<int64_t> inputShape = inputType.getShape();
2942  int64_t inputRank = inputShape.size();
2943  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2944  builder, inputRank, dim, /*allParallel=*/true);
2945  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2946  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2947  // Add the affine map for the output argument.
2948  indexingMaps.push_back(indexingMaps[0]);
2949  auto genericOp = linalg::GenericOp::create(
2950  builder, loc, input.getType(), ValueRange{input, max}, output,
2951  indexingMaps, iteratorTypes,
2952  [&](OpBuilder &b, Location loc, ValueRange args) {
2953  Value diff = arith::SubFOp::create(b, loc, args[0], args[1]);
2954  Value result = math::ExpOp::create(b, loc, diff);
2955  linalg::YieldOp::create(b, loc, result);
2956  });
2957  return genericOp.getResult(0);
2958 }
2959 
2960 /// Produce a linalg generic that computes the final step of the softmax
2961 /// decomposition.
2962 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2963 /// yield n / d
2964 /// }
2965 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2966  Value denominator, Value output, int64_t dim) {
2967  auto inputType = cast<ShapedType>(numerator.getType());
2968  ArrayRef<int64_t> inputShape = inputType.getShape();
2969  int64_t inputRank = inputShape.size();
2970  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2971  builder, inputRank, dim, /*allParallel=*/true);
2972  assert(indexingMaps.size() == 2 &&
2973  "We should have one map for each input (2)");
2974  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2975  // Add the affine map for the output tensor.
2976  indexingMaps.push_back(indexingMaps[0]);
2977  auto genericOp = linalg::GenericOp::create(
2978  builder, loc, numerator.getType(), ValueRange{numerator, denominator},
2979  output, indexingMaps, iteratorTypes,
2980  [&](OpBuilder &b, Location loc, ValueRange args) {
2981  Value result = arith::DivFOp::create(b, loc, args[0], args[1]);
2982  linalg::YieldOp::create(b, loc, result);
2983  });
2984  return genericOp.getResult(0);
2985 }
2986 // @} End helper functions for softmax decomposition.
2987 
2988 /// Given an N-dimensional tensor x, this method converts
2989 /// softmax(x) to the following sequence of operations:
2990 ///
2991 /// 1. Compute the max of x along dimension d. This results
2992 /// in an (N-1)-dimensional tensor m.
2993 /// m = max(x, dim = d)
2994 ///
2995 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2996 /// an N-dimensional tensor z.
2997 /// z = exp(x - m)
2998 ///
2999 /// 3. Compute the sum of z along dimension d. This results in
3000 /// an (N-1)-dimensional tensor l.
3001 /// l = sum(z, dim = d)
3002 ///
3003 /// 4. Divide z and l. This gives the N-dimensional softmax.
3004 /// softmax = z / l
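///
/// For illustration, a rough sketch of the sequence built below for a
/// `tensor<2x3xf32>` input with d = 1 (schematic pseudo-IR, not verified
/// syntax; each reduce/generic is a linalg.generic in the actual output):
///   %red = tensor.empty() : tensor<2xf32>
///   %m   = max-reduce(%x) into linalg.fill(-inf, %red)   // step 1
///   %z   = generic { exp(%x - broadcast %m) }            // step 2
///   %l   = add-reduce(%z) into linalg.fill(0.0, %red)    // step 3
///   %r   = generic { %z / broadcast %l }                 // step 4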
3005 ///
3006 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
3007  OpBuilder::InsertionGuard guard(b);
3008  b.setInsertionPoint(*this);
3009  Location loc = getLoc();
3010  Value input = getInput();
3011  ShapedType inputType = getInputOperandType();
3012  Type elementType = inputType.getElementType();
3013  int64_t reductionDim = getDimension();
3014  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
3015  Value output = getOutput();
3016  dims.erase(dims.begin() + reductionDim);
3017  // Step 1: Compute max along dim.
3018  Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType);
3019  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
3020  elementType, b, loc,
3021  /*useOnlyFiniteValue=*/true);
3022  Value neutralForMaxFInit =
3023  linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce)
3024  .result();
3025  Value max =
3026  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
3027 
3028  // Step 2: Subtract max from input and exponentiate.
3029  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
3030 
3031  // Step 3: Compute sum along dim.
3032  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
3033  b, loc, /*useOnlyFiniteValue=*/true);
3034  Value zeroInit =
3035  linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result();
3036  Value denominator =
3037  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
3038 
3039  // Step 4: Compute softmax.
3040  Value result =
3041  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
3042  return SmallVector<Value>{result};
3043 }
3044 
3045 //===----------------------------------------------------------------------===//
3046 // WinogradFilterTransformOp
3047 //===----------------------------------------------------------------------===//
3048 
3049 LogicalResult WinogradFilterTransformOp::verify() {
3050  auto filterType = cast<ShapedType>(getFilter().getType());
3051  ArrayRef<int64_t> filterShape = filterType.getShape();
3052  int64_t filterH = filterShape[getFilterHDim()];
3053  int64_t filterW = filterShape[getFilterWDim()];
3054  WinogradConv2DFmr fmr = getFmr();
3055  int64_t m, r;
3056  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3057 
3058  if (filterH != r && filterH != 1)
3059  return emitOpError("expected filter height to be r or 1");
3060  if (filterW != r && filterW != 1)
3061  return emitOpError("expected filter width to be r or 1");
3062  if (filterH == 1 && filterW == 1)
3063  return emitOpError("expected either filter height or width to be r");
3064 
3065  SmallVector<int64_t> expectedOutputShape;
3066  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3067  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3068  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3069  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3070 
3071  auto outputType = cast<ShapedType>(getOutput().getType());
3072  ArrayRef<int64_t> outputShape = outputType.getShape();
3073  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3074  return emitOpError("the output shape does not match the expected shape");
3075  }
3076  return success();
3077 }
3078 
3079 SmallVector<Range>
3080 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3081  Location loc = getLoc();
3082  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3083  IntegerAttr oneAttr = builder.getIndexAttr(1);
3084  Value filter = getFilter();
3085  int64_t filterRank = getFilterOperandRank();
3086  SmallVector<Range> loopBounds(filterRank);
3087  for (unsigned dim = 0; dim < filterRank; ++dim) {
3088  loopBounds[dim].offset = zeroAttr;
3089  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3090  loopBounds[dim].stride = oneAttr;
3091  }
3092  return loopBounds;
3093 }
3094 
3095 SmallVector<utils::IteratorType>
3096 WinogradFilterTransformOp::getLoopIteratorTypes() {
3097  int64_t filterRank = getFilterOperandRank();
3098  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3099  utils::IteratorType::parallel);
3100  return iteratorTypes;
3101 }
3102 
3103 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3104  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3105  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3106  SmallVector<OpFoldResult> &resultSizes) {
3107  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3108  ShapedType filterType = getFilterOperandType();
3109  ArrayRef<int64_t> filterShape = filterType.getShape();
3110  int64_t filterH = filterShape[getFilterHDim()];
3111  int64_t filterW = filterShape[getFilterWDim()];
3112  WinogradConv2DFmr fmr = getFmr();
3113  int64_t m, r;
3114  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3115  int64_t alpha = m + r - 1;
3116  int64_t alphaH = filterH != 1 ? alpha : 1;
3117  int64_t alphaW = filterW != 1 ? alpha : 1;
3118  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3119  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3120 
3121  resultOffsets.append(
3122  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3123  resultSizes.append(
3124  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3125 
3126  return success();
3127 }
3128 
3129 /// Implement tiling for winograd_filter_transform.
3130 /// The input of winograd_filter_transform is (F, KH, KW, C).
3131 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3132 /// Users can specify the tile sizes of F and C.
3133 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3134 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
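///
/// E.g., assuming fmr = F_2_3 (m = 2, r = 3, hence alpha = m + r - 1 = 4)
/// and a 3x3 filter: a tile over F and C extracts a (tileF, 3, 3, tileC)
/// slice of the filter and computes a (4, 4, tileC, tileF) slice of the
/// output; the KH and KW dimensions are never tiled.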
3135 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3136  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3137  ArrayRef<OpFoldResult> sizes) {
3138  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3139  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3140  ShapedType filterType = getFilterOperandType();
3141  ArrayRef<int64_t> filterShape = filterType.getShape();
3142  int64_t filterH = filterShape[getFilterHDim()];
3143  int64_t filterW = filterShape[getFilterWDim()];
3144  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3145  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3146  SmallVector<Value> tiledOperands;
3147  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3148 
3149  sliceOffsets.append(
3150  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3151  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3152  sizes[getFilterCDim()]});
3153  int64_t filterRank = getFilterOperandRank();
3154  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3155  Location loc = getLoc();
3156  auto filterSlice = tensor::ExtractSliceOp::create(
3157  builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3158  tiledOperands.emplace_back(filterSlice);
3159 
3160  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3161  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3162  resultSizes)))
3163  return failure();
3164 
3165  int64_t outputRank = getOutputOperandRank();
3166  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3167  auto outputSlice = tensor::ExtractSliceOp::create(
3168  builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3169  tiledOperands.emplace_back(outputSlice);
3170 
3171  SmallVector<Type> resultTypes;
3172  resultTypes.push_back(tiledOperands[1].getType());
3173  Operation *tiledOp =
3174  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3175 
3176  return TilingResult{
3177  {tiledOp},
3178  SmallVector<Value>(tiledOp->getResults()),
3179  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3180 }
3181 
3182 //===----------------------------------------------------------------------===//
3183 // WinogradInputTransformOp
3184 //===----------------------------------------------------------------------===//
3185 
3186 LogicalResult WinogradInputTransformOp::verify() {
3187  auto inputType = cast<ShapedType>(getInput().getType());
3188  ArrayRef<int64_t> inputShape = inputType.getShape();
3189  int64_t inputH = inputShape[getInputHDim()];
3190  int64_t inputW = inputShape[getInputWDim()];
3191  WinogradConv2DFmr fmr = getFmr();
3192  int64_t m, r;
3193  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3194  int64_t tileSize = m + r - 1;
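  // E.g., with F_4_3 (m = 4, r = 3): tileSize = 6, and a static input height
  // of 14 yields (14 - (r - 1)) / m = (14 - 2) / 4 = 3 tiles along H.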
3195 
3196  auto outputType = cast<ShapedType>(getOutput().getType());
3197  ArrayRef<int64_t> outputShape = outputType.getShape();
3198  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3199  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3200 
3201  SmallVector<int64_t> expectedOutputShape(6, inputH);
3202  if (ShapedType::isDynamic(inputH)) {
3203  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3204  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3205  } else {
3206  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3207  expectedOutputShape[getOutputTileHDim()] =
3208  leftTransform ? (inputH - (r - 1)) / m : inputH;
3209  }
3210  if (ShapedType::isDynamic(inputW)) {
3211  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3212  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3213  } else {
3214  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3215  expectedOutputShape[getOutputTileWDim()] =
3216  rightTransform ? (inputW - (r - 1)) / m : inputW;
3217  }
3218  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3219  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3220 
3221  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3222  return emitOpError("the output shape does not match the expected shape");
3223  }
3224  return success();
3225 }
3226 
3227 SmallVector<Range>
3228 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3229  Location loc = getLoc();
3230  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3231  IntegerAttr oneAttr = builder.getIndexAttr(1);
3232  Value output = getOutput();
3233  int64_t outputRank = getOutputOperandRank();
3234  SmallVector<Range> loopBounds(outputRank);
3235  for (unsigned dim = 0; dim < outputRank; ++dim) {
3236  loopBounds[dim].offset = zeroAttr;
3237  // alphaH, alphaW, tileH, tileW, N, C
3238  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3239  loopBounds[dim].stride = oneAttr;
3240  }
3241  return loopBounds;
3242 }
3243 
3244 SmallVector<utils::IteratorType>
3245 WinogradInputTransformOp::getLoopIteratorTypes() {
3246  int64_t outputRank = getOutputOperandRank();
3247  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3248  utils::IteratorType::parallel);
3249  return iteratorTypes;
3250 }
3251 
3252 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3253  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3254  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3255  SmallVector<OpFoldResult> &resultSizes) {
3256  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3257  ShapedType outputType = getOutputOperandType();
3258  ArrayRef<int64_t> outputShape = outputType.getShape();
3259  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3260  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3261 
3262  WinogradConv2DFmr fmr = getFmr();
3263  int64_t m, r;
3264  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3265  int64_t alpha = m + r - 1;
3266  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3267  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3268 
3269  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3270  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3271 
3272  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3273  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3274  offsets[getOutputCDim()]});
3275  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3276  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3277  sizes[getOutputCDim()]});
3278 
3279  return success();
3280 }
3281 
3282 /// Implement tiling for winograd_input_transform.
3283 /// The input of winograd_input_transform is (N, H, W, C).
3284 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3285 /// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3286 /// `offsets` are the values for the offsets of tileH, tileW, N, C for one
3287 /// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one tile.
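///
/// E.g., assuming fmr = F_2_3 (m = 2, r = 3): a tile of `size` output tiles
/// along H starts at input row `offset * m` and spans `size * m + (r - 1)`
/// input rows, which is exactly what the affine maps composed below compute.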
3288 FailureOr<TilingResult>
3289 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3290  ArrayRef<OpFoldResult> offsets,
3291  ArrayRef<OpFoldResult> sizes) {
3292  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3293  WinogradConv2DFmr fmr = getFmr();
3294  int64_t m, r;
3295  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3296 
3297  ShapedType outputType = getOutputOperandType();
3298  ArrayRef<int64_t> outputShape = outputType.getShape();
3299  int64_t alphaH = outputShape[getOutputAlphaHDim()];
3300  int64_t alphaW = outputShape[getOutputAlphaWDim()];
3301 
3302  Location loc = getLoc();
3303  MLIRContext *context = builder.getContext();
3304  auto identityAffineMap =
3305  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3306  auto offsetAffineMap =
3307  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3308  Value mappedOffsetH = affine::makeComposedAffineApply(
3309  builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3310  offsets[getOutputTileHDim()]);
3311  Value mappedOffsetW = affine::makeComposedAffineApply(
3312  builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3313  offsets[getOutputTileWDim()]);
3314  auto sizeAffineMap = AffineMap::get(
3315  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3316  Value mappedSizeH = affine::makeComposedAffineApply(
3317  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3318  Value mappedSizeW = affine::makeComposedAffineApply(
3319  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3320 
3321  SmallVector<Value> tiledOperands;
3322  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3323 
3324  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3325  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3326  sliceOffsets.append(
3327  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3328  OpFoldResult sizeH =
3329  alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3330  OpFoldResult sizeW =
3331  alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3332  sliceSizes.append(
3333  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3334  int64_t inputRank = getInputOperandRank();
3335  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3336  auto inputSlice = tensor::ExtractSliceOp::create(
3337  builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3338  tiledOperands.emplace_back(inputSlice);
3339 
3340  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3341  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3342  resultSizes)))
3343  return failure();
3344 
3345  int64_t outputRank = getOutputOperandRank();
3346  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3347  auto outputSlice = tensor::ExtractSliceOp::create(
3348  builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3349  tiledOperands.emplace_back(outputSlice);
3350 
3351  SmallVector<Type> resultTypes;
3352  resultTypes.push_back(tiledOperands[1].getType());
3353  Operation *tiledOp =
3354  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3355 
3356  return TilingResult{
3357  {tiledOp},
3358  SmallVector<Value>(tiledOp->getResults()),
3359  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3360 }
3361 
3362 //===----------------------------------------------------------------------===//
3363 // WinogradOutputTransformOp
3364 //===----------------------------------------------------------------------===//
3365 
3366 LogicalResult WinogradOutputTransformOp::verify() {
3367  auto valueType = cast<ShapedType>(getValue().getType());
3368  ArrayRef<int64_t> valueShape = valueType.getShape();
3369  int64_t valueH = valueShape[getValueAlphaHDim()];
3370  int64_t valueW = valueShape[getValueAlphaWDim()];
3371  int64_t valueTileH = valueShape[getValueTileHDim()];
3372  int64_t valueTileW = valueShape[getValueTileWDim()];
3373  WinogradConv2DFmr fmr = getFmr();
3374  int64_t m, r;
3375  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3376  bool leftTransform = valueH != 1;
3377  bool rightTransform = valueW != 1;
3378 
3379  int64_t outputRank = getOutputOperandRank();
3380  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3381  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3382  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3383  } else {
3384  if (valueH != (leftTransform ? m + r - 1 : 1))
3385  return emitOpError("expected input height to equal the input tile size");
3386  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3387  }
3388  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3389  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3390  } else {
3391  if (valueW != (rightTransform ? m + r - 1 : 1))
3392  return emitOpError("expected input width to equal the input tile size");
3393  expectedOutputShape[getOutputWDim()] =
3394  (rightTransform ? m : 1) * valueTileW;
3395  }
3396  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3397  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3398 
3399  auto outputType = cast<ShapedType>(getOutput().getType());
3400  ArrayRef<int64_t> outputShape = outputType.getShape();
3401  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3402  return emitOpError("the output shape does not match the expected shape");
3403  }
3404  return success();
3405 }
3406 
3407 SmallVector<Range>
3408 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3409  Location loc = getLoc();
3410  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3411  IntegerAttr oneAttr = builder.getIndexAttr(1);
3412  Value value = getValue();
3413  int64_t valueRank = getValueOperandRank();
3414  SmallVector<Range> loopBounds(valueRank);
3415  for (unsigned dim = 0; dim < valueRank; ++dim) {
3416  loopBounds[dim].offset = zeroAttr;
3417  // alphaH, alphaW, tileH, tileW, N, F
3418  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3419  loopBounds[dim].stride = oneAttr;
3420  }
3421  return loopBounds;
3422 }
3423 
3424 SmallVector<utils::IteratorType>
3425 WinogradOutputTransformOp::getLoopIteratorTypes() {
3426  int64_t valueRank = getValueOperandRank();
3427  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3428  utils::IteratorType::parallel);
3429  return iteratorTypes;
3430 }
3431 
3432 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3433  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3434  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3435  SmallVector<OpFoldResult> &resultSizes) {
3436  WinogradConv2DFmr fmr = getFmr();
3437  int64_t m, r;
3438  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
3439 
3440  Location loc = getLoc();
3441  MLIRContext *context = builder.getContext();
3442  auto identityAffineMap =
3443  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3444  auto affineMap =
3445  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3446 
3447  ShapedType valueType = getValueOperandType();
3448  ArrayRef<int64_t> valueShape = valueType.getShape();
3449  int64_t valueH = valueShape[0];
3450  int64_t valueW = valueShape[1];
3451  Value mappedOffsetH = affine::makeComposedAffineApply(
3452  builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3453  offsets[getValueTileHDim()]);
3454  Value mappedOffsetW = affine::makeComposedAffineApply(
3455  builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3456  offsets[getValueTileWDim()]);
3457  Value mappedSizeH = affine::makeComposedAffineApply(
3458  builder, loc, affineMap, sizes[getValueTileHDim()]);
3459  Value mappedSizeW = affine::makeComposedAffineApply(
3460  builder, loc, affineMap, sizes[getValueTileWDim()]);
3461 
3462  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3463  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3464  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3465  OpFoldResult sizeH =
3466  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3467  OpFoldResult sizeW =
3468  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3469 
3470  resultOffsets.append(
3471  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3472  resultSizes.append(
3473  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3474  return success();
3475 }
3476 
3477 /// Implement tiling for winograd_output_transform.
3478 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3479 /// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
3480 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3481 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3482 /// for the sizes of tileH, tileW, N, F for one tile.
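///
/// E.g., assuming fmr = F_2_3 (m = 2): a tile of `size` value tiles along H
/// maps to output rows [offset * m, offset * m + size * m), since each
/// Winograd tile contributes an m x m block to the output when both
/// transforms apply.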
3483 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3484  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3485  ArrayRef<OpFoldResult> sizes) {
3486  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3487  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3488  Location loc = getLoc();
3489  SmallVector<Value> tiledOperands;
3490  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3491 
3492  ShapedType valueType = getValueOperandType();
3493  ArrayRef<int64_t> valueShape = valueType.getShape();
3494  int64_t alphaH = valueShape[getValueAlphaHDim()];
3495  int64_t alphaW = valueShape[getValueAlphaWDim()];
3496  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3497  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3498 
3499  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3500  offsets[getValueTileWDim()], offsets[getValueNDim()],
3501  offsets[getValueFDim()]});
3502  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3503  sizes[getValueTileWDim()], sizes[getValueNDim()],
3504  sizes[getValueFDim()]});
3505  int64_t valueRank = getValueOperandRank();
3506  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3507  auto valueSlice = tensor::ExtractSliceOp::create(
3508  builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3509  tiledOperands.emplace_back(valueSlice);
3510 
3511  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3512  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3513  resultSizes)))
3514  return failure();
3515 
3516  int64_t outputRank = getOutputOperandRank();
3517  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3518  auto outputSlice = tensor::ExtractSliceOp::create(
3519  builder, loc, getOutput(), resultOffsets, resultSizes, strides);
3520  tiledOperands.emplace_back(outputSlice);
3521 
3522  SmallVector<Type> resultTypes;
3523  resultTypes.push_back(tiledOperands[1].getType());
3524  Operation *tiledOp =
3525  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3526 
3527  return TilingResult{
3528  {tiledOp},
3529  SmallVector<Value>(tiledOp->getResults()),
3530  llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3531 }
3532 
3533 //===----------------------------------------------------------------------===//
3534 // LinalgDialect
3535 // TODO: Merge with the LinalgDialect block at the bottom
3536 //===----------------------------------------------------------------------===//
3537 
3538 // Returns true if the result expressions of `subMap` are a subset of `fullMap`.
3539 static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3540  auto explicitRange = subMap.getResults();
3541  auto defaultRange = fullMap.getResults();
3542  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3543  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3544  llvm::set_union(explicitSet, defaultSet);
3545  return explicitSet == defaultSet;
3546 }
3547 
3548 /// Check if the user-defined map is a valid broadcast map. Broadcast
3549 /// indexing maps are defined relative to the corresponding default indexing
3550 /// maps for the given op, so the check reduces to comparing the number of
3551 /// result dims.
3552 /// Returns true if `explicitMap` is broadcast with respect to
3553 /// `defaultMap`.
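/// E.g., the default LHS map of matmul is (d0, d1, d2) -> (d0, d2); a
/// broadcast LHS map such as (d0, d1, d2) -> (d2) has fewer results and is
/// therefore detected here.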
3554 static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3555  return explicitMap.getNumResults() < defaultMap.getNumResults();
3556 }
3557 
3558 /// Verifies the broadcast and transpose semantics specified by the explicit
3559 /// indexing map of the MatmulOp \p matmulOp for the operand specified by \p
3560 /// opIndex.
3561 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3562  unsigned opIndex) {
3563  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3564  SmallVector<AffineMap, 3> defaultIndexingMaps =
3565  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3566 
3567  auto opIndexingMap = opIndexingMaps[opIndex];
3568  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3569  // Check general validity of indexing map results.
3570  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3571  return matmulOp->emitOpError()
3572  << "Unexpected dim expression in map result.";
3573 
3574  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3575  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3576  return matmulOp->emitOpError()
3577  << "Invalid broadcast requested, should be (d2).";
3578  }
3579  return success();
3580  }
3581  return success();
3582 }
3583 
3584 // Check general validity of input indexing map of
3585 // BatchMatmulOp/BatchReduceMatmulOp.
3586 template <typename OpTy>
3587 static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3588  AffineMap opIndexingMap,
3589  AffineMap defaultIndexingMap, bool isLHS) {
3590  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3591  isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3592  "Expected BatchMatmulOp or BatchReduceMatmulOp");
3593  // Check the result dims are valid.
3594  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3595  return batchVariantMatmulOp->emitOpError()
3596  << "Unexpected result dim expression (outside the set of default "
3597  "result dims).";
3598 
3599  // Check for valid number of result dims of input maps.
3600  if (opIndexingMap.getNumResults() > 3)
3601  return batchVariantMatmulOp->emitOpError()
3602  << "no. of result dim expressions exceeds 3.";
3603 
3604  auto hasValidBatchDim = [](AffineMap map) {
3605  AffineExpr batchDim = map.getResult(0);
3606  return batchDim.isFunctionOfDim(0);
3607  };
3608 
3609  // Check if the requested broadcast is valid.
3610  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3611  if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3612  return batchVariantMatmulOp->emitOpError()
3613  << "Invalid broadcast requested.";
3614  } else if (!hasValidBatchDim(opIndexingMap)) {
3615  return batchVariantMatmulOp->emitOpError()
3616  << "Invalid batch dimension expression.";
3617  }
3618  return success();
3619 }
3620 
3621 /// This function checks if the given AffineMap for the output of a
3622 /// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3623 /// dimensions and if the output map result dimensions are valid.
3624 template <typename OpTy>
3625 static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3626  AffineMap opIndexingMap) {
3627  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3628  isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3629  "Expected BatchMatmulOp or BatchReduceMatmulOp");
3630  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
3631  opIndexingMap.getNumResults() != 3) {
3632 
3633  return batchVariantMatmulOp->emitOpError()
3634  << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3635  << ").";
3636  }
3637  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3638  opIndexingMap.getNumResults() != 2) {
3639  return batchVariantMatmulOp->emitOpError()
3640  << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3641  << ").";
3642  }
3643 
3644  auto areValidOutputResultDim = [&](AffineMap outputMap) {
3645  return isa<BatchMatmulOp>(batchVariantMatmulOp)
3646  ? outputMap.getResult(0).isFunctionOfDim(0) &&
3647  outputMap.getResult(1).isFunctionOfDim(1) &&
3648  outputMap.getResult(2).isFunctionOfDim(2)
3649  : outputMap.getResult(0).isFunctionOfDim(1) &&
3650  outputMap.getResult(1).isFunctionOfDim(2);
3651  };
3652 
3653  if (!areValidOutputResultDim(opIndexingMap)) {
3654  return batchVariantMatmulOp->emitOpError()
3655  << "Invalid output map result dimension.";
3656  }
3657 
3658  return success();
3659 }
3660 
3661 /// Verifies the broadcast and transpose semantics specified by the explicit
3662 /// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3663 /// specified by opIndex.
3664 template <typename OpTy>
3665 static LogicalResult
3666 verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3667  unsigned opIndex) {
3668  SmallVector<AffineMap, 3> opIndexingMaps =
3669  batchVariantMatmulOp.getIndexingMapsArray();
3670  SmallVector<AffineMap, 3> defaultIndexingMaps =
3671  batchVariantMatmulOp.getDefaultIndexingMaps(
3672  batchVariantMatmulOp->getContext());
3673 
3674  if (opIndexingMaps.size() != 3)
3675  return batchVariantMatmulOp->emitOpError()
3676  << "indexing_maps attribute must have 3 affine maps.";
3677 
3678  auto opIndexingMap = opIndexingMaps[opIndex];
3679  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3680 
3681  if (opIndex == 2 &&
3682  failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3683  return failure();
3684 
3685  if (opIndex != 2 &&
3686  failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3687  defaultIndexingMap, opIndex == 0)))
3688  return failure();
3689 
3690  return success();
3691 }
3692 
3693 namespace mlir {
3694 namespace linalg {
3695 
3696 std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3697  if (m == 2 && r == 3)
3698  return WinogradConv2DFmr::F_2_3;
3699  if (m == 4 && r == 3)
3700  return WinogradConv2DFmr::F_4_3;
3701  if (m == 2 && r == 5)
3702  return WinogradConv2DFmr::F_2_5;
3703  return std::nullopt;
3704 }
3705 
3706 std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3707  switch (fmr) {
3708  case WinogradConv2DFmr::F_2_3:
3709  return {2, 3};
3710  case WinogradConv2DFmr::F_4_3:
3711  return {4, 3};
3712  case WinogradConv2DFmr::F_2_5:
3713  return {2, 5};
3714  }
  // The switch above covers every enum case; keep -Wreturn-type happy.
  llvm_unreachable("unknown WinogradConv2DFmr");
3715 }
3716 
3717 //===----------------------------------------------------------------------===//
3718 // MatmulOp
3719 //===----------------------------------------------------------------------===//
3720 
3721 /// Returns a list of AffineMaps with the typical matmul indexing characteristics.
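/// With dims (d0, d1, d2) = (M, N, K), the maps constructed below are:
///   A: (d0, d1, d2) -> (d0, d2)
///   B: (d0, d1, d2) -> (d2, d1)
///   C: (d0, d1, d2) -> (d0, d1)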
3722 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3723  AffineExpr d0, d1, d2;
3724  SmallVector<AffineMap> indexingMaps;
3725  bindDims(context, d0, d1, d2);
3726  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3727  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3728  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3729  return indexingMaps;
3730 }
3731 
3732 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3733  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3734  utils::IteratorType::parallel,
3735  utils::IteratorType::reduction};
3736 }
3737 
3738 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3739 
3740 std::string MatmulOp::getLibraryCallName() {
3741  return generateLibraryCallName(getOperation());
3742 }
3743 
3744 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3745 
3746 /// Check if the op has broadcast and/or transpose semantic. Returns true if
3747 /// the user defined indexing maps are not equal to default map.
3748 bool MatmulOp::hasUserDefinedMaps() {
3749  SmallVector<AffineMap, 3> defaultMaps =
3750  getDefaultIndexingMaps(this->getContext());
3751  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3752  return defaultMaps != explicitMaps;
3753 }
3754 
3755 /// Implements the block region builder for the MatmulOp. This is called by
3756 /// 'fillStructuredOpRegion'.
3757 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3758  ArrayRef<NamedAttribute> attrs,
3759  function_ref<InFlightDiagnostic()> emitError) {
3760  if (emitError && block.getNumArguments() != 3) {
3761  emitError() << "MatmulOp regionBuilder expects 3 args, got "
3762  << block.getNumArguments();
3763  return;
3764  }
3765  assert(block.getNumArguments() == 3 &&
3766  "MatmulOp regionBuilder expects 3 args");
3767  RegionBuilderHelper helper(b, block);
3768  SmallVector<Value> yields;
3769 
3770  TypeFn castVal = TypeFn::cast_signed;
3771  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3772  return attr.getName() == "cast";
3773  });
3774  if (castIter != attrs.end()) {
3775  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3776  castVal = attr.getValue();
3777  }
3778 
3779  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3780  block.getArgument(0));
3781  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3782  block.getArgument(1));
3783  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
3784  if (!value3)
3785  return;
3786  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3787  value3, emitError);
3788  if (!value4)
3789  return;
3790  yields.push_back(value4);
3791  helper.yieldOutputs(yields);
3792 }
3793 
3794 /// Returns true if the given bcastMap map is a valid broadcast map. A valid
3795 /// broadcast map must include K dimension.
3796 /// TODO: Strict inclusion of K dimension in the broadcast map is not
3797 /// necessary for both input matrices simultaneously. We can relax this
3798 /// condition to have K dimension for one input matrix map and infer the K
3799 /// dimension for other input matrix map from the one already having K
3800 /// dimension.
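/// E.g., affine_map<(d0, d1, d2) -> (d2)> is a valid broadcast map (d2 is
/// the K dimension of matmul), while affine_map<(d0, d1, d2) -> (d0)> is
/// not, since it drops K.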
3801 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3802  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3803  AffineExpr expr = bcastMap.getResult(0);
3804  // Invalid map if the common dimension of the matmul is not found.
3805  return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3806 }
3807 
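/// Parses an optional `indexing_maps` attribute of the form
///   indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, ...]
/// (a sketch of the accepted syntax). Returns a null ArrayAttr when the
/// keyword is absent, and failure when the attribute is malformed or an
/// element is not an affine_map.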
3808 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3809  if (parser.parseOptionalKeyword("indexing_maps"))
3810  return ArrayAttr{
3811  nullptr}; // Success in case indexing_maps was not provided.
3812 
3813  ArrayAttr arrayAttr;
3814  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3815  return failure();
3816 
3817  if (llvm::any_of(arrayAttr,
3818  [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3819  return parser.emitError(parser.getCurrentLocation())
3820  << "element of indexing_maps array is not an affine_map";
3821 
3822  return arrayAttr;
3823 }
3824 
3825 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3826  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3827  if (failed(indexingMapsAttr))
3828  return failure();
3829 
3830  if (*indexingMapsAttr == nullptr) {
3831  auto indexingMapAttrs = llvm::map_to_vector(
3832  MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3833  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3834  indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3835  }
3836 
3837  result.addAttribute("indexing_maps", *indexingMapsAttr);
3838  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3839  MatmulOp::getRegionBuilder());
3840 }
3841 
3842 void MatmulOp::print(OpAsmPrinter &p) {
3843  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3844  MatmulOp::getDefaultIndexingMaps(getContext()),
3845  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3846  if (!llvm::equal(getIndexingMaps(), indexingMaps))
3847  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3848 
3849  std::array<StringRef, 3> elidedAttrs = {
3850  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3851  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3852  elidedAttrs);
3853 }
3854 
3855 /// Verify the user defined indexing maps.
3856 LogicalResult MatmulOp::verify() {
3857  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3858  if (!hasUserDefinedMaps())
3859  return success();
3860 
3861  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3862  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3863  return failure();
3864  }
3865  return success();
3866 }
3867 
3868 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3869  return memref::foldMemRefCast(*this);
3870 }
3871 
3872 void MatmulOp::getEffects(
3873  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3874  &effects) {
3875  if (hasPureTensorSemantics())
3876  return;
3877  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3878 }
3879 
3880 Speculation::Speculatability MatmulOp::getSpeculatability() {
3881  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3882 }
3883 
3884 //===----------------------------------------------------------------------===//
3885 // ContractOp
3886 //===----------------------------------------------------------------------===//
3887 
3888 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3889  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3890  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3891  // domains are all the same, and each implements a projected permutation.
3892  // Each iteration space dim must occur for at least one operand and either
3893  // takes part in a contraction/reduction or else has parallel iteration type.
3894  // We have that a dim is a contraction/reduction dim if and only if the dim
3895  // does not occur in the output map. We use this fact for fast inference:
3896  // NB: In case we allow dims to occur solely for one input, the above still
3897  // holds: per the einsum semantics, these are reduction dims as well.
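  // E.g., a matmul expressed as linalg.contract has input maps (d0, d2) and
  // (d2, d1) and output map (d0, d1): d2 does not appear in the output map
  // and is hence inferred to be a reduction dim.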
3898  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3899  for (auto result : outAffineMap.getResults()) {
3900  auto dimExpr = dyn_cast<AffineDimExpr>(result);
3901  assert(dimExpr && "affine_map is a projected permutation");
3902  dimsInOutput[dimExpr.getPosition()] = true;
3903  }
3904 
3905  SmallVector<utils::IteratorType> iteratorTypes;
3906  for (auto dimOccursInOutput : dimsInOutput)
3907  iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3908  : utils::IteratorType::reduction);
3909 
3910  return iteratorTypes;
3911 }
3912 
3913 unsigned ContractOp::getNumRegionArgs() { return 3; }
3914 
3915 /// Implements the block region builder, which is called by 'fillStructuredOpRegion'.
3916 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3917  ArrayRef<NamedAttribute> attrs,
3918  function_ref<InFlightDiagnostic()> emitError) {
3919  if (emitError && block.getNumArguments() != 3) {
3920  emitError() << "ContractOp regionBuilder expects 3 args, got "
3921  << block.getNumArguments();
3922  return;
3923  }
3924  assert(block.getNumArguments() == 3 &&
3925  "ContractOp regionBuilder expects 3 args");
3926  RegionBuilderHelper helper(b, block);
3927 
3928  TypeFn castSignedness = TypeFn::cast_signed;
3929  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3930  return attr.getName() == "cast";
3931  });
3932  if (castIter != attrs.end()) {
3933  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3934  castSignedness = attr.getValue();
3935  }
3936 
3937  // TODO: Support fields with operators besides mult & add.
3938  Type outType = block.getArgument(2).getType();
3939  Value lhsAtOutType =
3940  helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3941  Value rhsAtOutType =
3942  helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3943  Value productAtOutType = helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType,
3944  rhsAtOutType, emitError);
3945  if (!productAtOutType)
3946  return;
3947  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3948  productAtOutType, emitError);
3949  if (!result)
3950  return;
3951  helper.yieldOutputs({result});
3952 }
3953 
3954 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3955  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3956  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3957  return parser.emitError(parser.getCurrentLocation(),
3958  "expected 'indexing_maps' attribute");
3959  result.addAttribute("indexing_maps", *indexingMapsAttr);
3960 
3961  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3962  regionBuilder);
3963 }
3964 
3965 void ContractOp::print(OpAsmPrinter &p) {
3966  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3967  printNamedStructuredOp(
3968  p, getOperation(), getInputs(), getOutputs(),
3969  /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3970 }
3971 
3972 LogicalResult ContractOp::verify() {
3973  int iterationSpaceDims = -1;
3974  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3975  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3976  // access an input operand (so occurrence count can be at most 2) and
3977  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
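  // (E.g., for a matmul-shaped contract: inOccurrences = [1, 1, 2] and
  // outOccurrences = [1, 1, 0], making d2 the contracting dim.)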
3978  SmallVector<size_t> inOccurrences;
3979  SmallVector<size_t> outOccurrences;
3980 
3981  // A helper so that for each operand's affine_map and type we check that ...
3982  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3983  bool isInput) -> LogicalResult {
3984  // ... the affine_map is a projected permutation;
3985  if (!affineMap.isProjectedPermutation())
3986  return emitError("provided affine_map is not a projected permutation");
3987 
3988  // ... the rank of the affine_map's results and corresponding type match;
3989  if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3990  if (affineMap.getNumResults() != shapedType.getRank())
3991  return emitError("ranks of shaped operand and results of corresponding "
3992  "affine_map differ");
3993  } else if (affineMap.getNumResults() != 0) {
3994  return emitError("affine_map specifies shaped access while operand has "
3995  "non-shaped type");
3996  }
3997 
3998  // ... the rank of the affine_map's domain is the same as those seen prior;
3999  if (iterationSpaceDims == -1) {
4000  iterationSpaceDims = affineMap.getNumDims();
4001  inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4002  outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
4003  } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
4004  return emitError("iteration spaces of provided affine_maps differ");
4005  }
4006 
4007  // ... update counts of dims used to access either an input or the output.
4008  for (AffineExpr affineExpr : affineMap.getResults()) {
4009  auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
4010  if (!affineDimExpr)
4011  llvm_unreachable("affine_map is a projected permutation");
4012 
4013  if (isInput)
4014  inOccurrences[affineDimExpr.getPosition()] += 1;
4015  else
4016  outOccurrences[affineDimExpr.getPosition()] += 1;
4017  }
4018 
4019  return success();
4020  };
4021 
4022  for (auto &&[affineMap, operandType, isInput] :
4023  llvm::zip(getIndexingMapsArray(), getOperandTypes(),
4024  SmallVector<bool>{true, true, false})) {
4025  if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
4026  return failure(); // NB: checkAffineMapAndType will emit relevant error.
4027  }
4028 
4029  bool hasContractingDim = false;
4030  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
4031  size_t inOccCount = inOccurrences[dimIndex];
4032  size_t outOccCount = outOccurrences[dimIndex];
4033 
4034  // We have a contracting dim if and only if ...
4035  hasContractingDim |= inOccCount == 2 && outOccCount == 0;
4036 
4037  if (inOccCount == 0 && outOccCount == 0)
4038  return emitError() << "iteration space dim at index " << dimIndex
4039  << " not used to access any operand";
4040 
4041  // NB: We disallow a dim which occurs for only one input operand and not
4042  // for the output. In terms of einsum semantics such dims have a
4043  // sensible meaning - namely an additional reduction per each such dim.
4044  // By contrast, the ContractionOpInterface does not know about this
4045  // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
4046  // while vector.contract's verifier accepts dims of this kind many of
4047  // its lowerings give up on encountering these dims.
4048  // TODO: Remove following once we have comprehensive support for input-only
4049  // reduction dims, at both the linalg- and vector-dialect levels.
4050  if (inOccCount == 1 && outOccCount != 1)
4051  return emitError()
4052  << "iteration space dim at index " << dimIndex
4053  << " is neither a contracting dim nor of parallel iteration type";
4054  }
4055 
4056  if (!hasContractingDim)
4057  return emitError("'indexing_maps' do not specify a contracting dimension");
4058 
4059  return success();
4060 }
4061 
4062 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4063  return memref::foldMemRefCast(*this);
4064 }
4065 
4066 void ContractOp::getEffects(
4067  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4068  &effects) {
4069  if (hasPureTensorSemantics())
4070  return;
4071  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4072 }
4073 
4074 Speculation::Speculatability ContractOp::getSpeculatability() {
4075  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4076 }
4077 
4078 //===----------------------------------------------------------------------===//
4079 // BatchMatmulOp
4080 //===----------------------------------------------------------------------===//
4081 SmallVector<AffineMap>
4082 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4083  AffineExpr d0, d1, d2, d3;
4084  SmallVector<AffineMap> indexingMaps;
4085  bindDims(context, d0, d1, d2, d3);
4086  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
4087  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
4088  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
4089  return indexingMaps;
4090 }
4091 
4092 SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4093  return SmallVector<utils::IteratorType>{
4094  utils::IteratorType::parallel, utils::IteratorType::parallel,
4095  utils::IteratorType::parallel, utils::IteratorType::reduction};
4096 }
4097 
4098 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4099 
4100 std::string BatchMatmulOp::getLibraryCallName() {
4101  return generateLibraryCallName(getOperation());
4102 }
4103 
4104 /// Check if the op has broadcast and/or transpose semantic. Returns true if
4105 /// the user defined indexing maps are not equal to default map.
4106 bool BatchMatmulOp::hasUserDefinedMaps() {
4107  SmallVector<AffineMap, 3> defaultMaps =
4108  getDefaultIndexingMaps(this->getContext());
4109  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4110  return defaultMaps != explicitMaps;
4111 }
4112 
4113 /// Returns true if the given bcastMap map is a valid broadcast map. A valid
4114 /// broadcast map must include K dimension.
4115 /// TODO: Strict inclusion of K dimension in the broadcast map is not
4116 /// necessary for both input matrices simultaneously. We can relax this
4117 /// condition to have K dimension for one input matrix map and infer the K
4118 /// dimension for other input matrix map from the one already having K
4119 /// dimension.
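/// E.g., with dims (d0, d1, d2, d3) = (batch, m, n, k), the LHS map
/// affine_map<(d0, d1, d2, d3) -> (d3)> (batch and m broadcast) is valid,
/// as is affine_map<(d0, d1, d2, d3) -> (d0, d3)> (m broadcast); any map
/// that drops the k dimension is rejected.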
4120 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4121  assert(bcastMap.getNumResults() < 3 &&
4122  "Expected less than 3 result dim expr.");
4123  bool isValid = false;
4124  enum Indices { batchPos, mPos, nPos, kPos };
4125  if (bcastMap.getNumResults() == 1) {
4126  AffineExpr expr = bcastMap.getResult(0);
4127  isValid = expr.isFunctionOfDim(kPos);
4128  } else if (bcastMap.getNumResults() == 2) {
4129  AffineExpr expr0 = bcastMap.getResult(0);
4130  AffineExpr expr1 = bcastMap.getResult(1);
4131  isValid =
4132  isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4133  expr0.isFunctionOfDim(mPos)) &&
4134  expr1.isFunctionOfDim(kPos))
4135  : ((expr0.isFunctionOfDim(batchPos) &&
4136  expr1.isFunctionOfDim(kPos)) ||
4137  (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4138  }
4139  return isValid;
4140 }
4141 
4142 void BatchMatmulOp::regionBuilder(
4143  ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4144  function_ref<InFlightDiagnostic()> emitError) {
4145  if (emitError && block.getNumArguments() != 3) {
4146  emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4147  << block.getNumArguments();
4148  return;
4149  }
4150  assert(block.getNumArguments() == 3 &&
4151  "BatchMatmulOp regionBuilder expects 3 args");
4152  RegionBuilderHelper helper(b, block);
4153  SmallVector<Value> yields;
4154 
4155  TypeFn castVal = TypeFn::cast_signed;
4156  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4157  return attr.getName() == "cast";
4158  });
4159  if (castIter != attrs.end()) {
4160  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4161  castVal = attr.getValue();
4162  }
4163 
4164  auto toType = block.getArgument(2).getType();
4165  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4166  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4167  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4168  Value addVal =
4169  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4170  yields.push_back(addVal);
4171  helper.yieldOutputs(yields);
4172 }
4173 
4174 ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4175  SmallVector<Attribute, 3> indexingMapsAttr;
4176  Attribute mapAttr;
4177  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4178  if (parser.parseEqual())
4179  return failure();
4180 
4181  if (parser.parseLSquare())
4182  return failure();
4183 
4184  do {
4185  if (parser.parseAttribute(mapAttr))
4186  return failure();
4187  if (!isa<AffineMapAttr>(mapAttr)) {
4188  return parser.emitError(parser.getCurrentLocation(),
4189  "expected affine map attribute");
4190  }
4191  indexingMapsAttr.push_back(mapAttr);
4192 
4193  if (parser.parseOptionalComma())
4194  break;
4195  } while (true);
4196 
4197  if (parser.parseRSquare())
4198  return failure();
4199  }
4200  // Initialize indexingMaps, if not supplied explicitly.
4201  if (indexingMapsAttr.empty()) {
4202  indexingMapsAttr = llvm::map_to_vector(
4203  BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4204  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4205  }
4206  result.addAttribute("indexing_maps",
4207  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4208 
4209  return ::parseNamedStructuredOp(parser, result,
4210  BatchMatmulOp::getNumRegionArgs(),
4211  BatchMatmulOp::getRegionBuilder());
4212 }
4213 
4214 void BatchMatmulOp::print(OpAsmPrinter &p) {
4215  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4216  BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4217  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4218  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4219  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4220 
4221  std::array<StringRef, 3> elidedAttrs = {
4222  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4223  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4224  elidedAttrs);
4225 }
4226 
4227 /// Verify the user defined indexing maps.
4228 LogicalResult BatchMatmulOp::verify() {
4229  // Verification of pure batch_matmul is handled by
4230  // verifyStructuredOpInterface().
4231  if (!hasUserDefinedMaps())
4232  return success();
4233 
4234  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4235  if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4236  return failure();
4237  }
4238  return success();
4239 }
4240 
4241 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4242  SmallVectorImpl<OpFoldResult> &) {
4243  return memref::foldMemRefCast(*this);
4244 }
4245 
4246 void BatchMatmulOp::getEffects(
4247  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4248  &effects) {
4249  if (hasPureTensorSemantics())
4250  return;
4251  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4252 }
4253 
4254 Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4255  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4256 }
4257 
4258 //===----------------------------------------------------------------------===//
4259 // ElementwiseOp
4260 //===----------------------------------------------------------------------===//
4262 namespace {
4263 struct ArityGroupAndKind {
4264  // The enum class {Unary, Binary, Ternary, ..}
4265  ElementwiseArityGroup arityGroup;
4266 
4267  // The kind (e.g. `exp` or `add`) belonging to the arity group.
4268  union Kind {
4269  UnaryFn unaryFn;
4270  BinaryFn binaryFn;
4271  TernaryFn ternaryFn;
4272  } kind;
4273 };
4274 
4275 unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4276  return static_cast<unsigned>(arityGroup);
4277 }
4278 } // namespace
4279 
4280 static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4281  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4282  constexpr int lastBinary =
4283  static_cast<int>(ElementwiseCaseLimits::LastBinary);
4284  constexpr int lastTernary =
4285  static_cast<int>(ElementwiseCaseLimits::LastTernary);
4286 
4287  int val = static_cast<int>(kind);
4288  ArityGroupAndKind result;
4289 
4290  if (val < lastUnary) {
4291  result.arityGroup = ElementwiseArityGroup::Unary;
4292  result.kind.unaryFn = static_cast<UnaryFn>(val);
4293  return result;
4294  }
4295  if (val < lastBinary) {
4296  result.arityGroup = ElementwiseArityGroup::Binary;
4297  result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4298  return result;
4299  }
4300  if (val >= lastTernary) {
4301  llvm_unreachable("unhandled ElementwiseFn");
4302  }
4303  result.arityGroup = ElementwiseArityGroup::Ternary;
4304  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4305  return result;
4306 }
4307 
4308 SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4309  auto rank = getResultRank();
4310  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4311 }
4312 
4313 SmallVector<AffineMap>
4314 ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4315  MLIRContext *context) {
4316  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4317  return SmallVector<AffineMap>(numMaps, map);
4318 }
4319 
4320 ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4321  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4322  Attribute attr;
4323  mlir::linalg::ElementwiseKind elemwiseKindVal;
4324  if (parser.parseKeyword("kind") || parser.parseEqual())
4325  return failure();
4326 
4327  if (succeeded(parser.parseAttribute(attr))) {
4328  auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4329  if (!elemwiseKindAttr)
4330  return parser.emitError(parser.getCurrentLocation(),
4331  "expected ElementwiseKind attribute");
4332  elemwiseKindVal = elemwiseKindAttr.getValue();
4333  } else {
4334  return parser.emitError(parser.getCurrentLocation(),
4335  "expected operation 'kind' attribute");
4336  }
4337  result.addAttribute(
4338  "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4339 
4340  // Parse optional `indexing_maps`
4341  SmallVector<Attribute, 3> indexingMapsAttr;
4342  Attribute mapAttr;
4343  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4344  if (parser.parseEqual())
4345  return failure();
4346  if (parser.parseLSquare())
4347  return failure();
4348  do {
4349  if (parser.parseAttribute(mapAttr))
4350  return failure();
4351  if (!isa<AffineMapAttr>(mapAttr))
4352  return parser.emitError(parser.getCurrentLocation(),
4353  "expected affine map attribute");
4354  indexingMapsAttr.push_back(mapAttr);
4355  if (parser.parseOptionalComma())
4356  break;
4357  } while (true);
4358  if (parser.parseRSquare())
4359  return failure();
4360  }
4361  // At this stage of parsing, the only way to infer the number of region
4362  // args is through the op kind, as the input/output tensors are not parsed yet.
4363  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4364  int numRegionArgs =
4365  getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4366  if (parseNamedStructuredOp(parser, result, numRegionArgs,
4367  ElementwiseOp::getRegionBuilder())) {
4368  return parser.emitError(parser.getCurrentLocation(),
4369  "unable to parse elemwise op");
4370  }
4371 
4372  // Initialize indexingMaps, if not supplied explicitly.
4373  if (indexingMapsAttr.empty()) {
4374  // We need to infer the numDims of the indexing maps from the output
4375  // type, which has already been parsed by now.
4376  auto resultType = result.operands[result.operands.size() - 1].getType();
4377  auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4378  if (!shapedType)
4379  return parser.emitError(parser.getCurrentLocation(),
4380  "return type needs to be shaped type");
4381  auto numDims = shapedType.getRank();
4382  indexingMapsAttr = llvm::map_to_vector(
4383  ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4384  parser.getContext()),
4385  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4386  }
4387 
4388  result.addAttribute("indexing_maps",
4389  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4390  return success();
4391 }
4392 
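// Schematic textual form handled by the custom parser above and the printer
// below (the `indexing_maps` clause is elided when it matches the defaults):
//
//   %res = linalg.elementwise kind=#linalg.elemwise_kind<add>
//            ins(%a, %b : tensor<8x16xf32>, tensor<8x16xf32>)
//            outs(%c : tensor<8x16xf32>) -> tensor<8x16xf32>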
4394  p << " kind=";
4395  p.printAttribute(getKindAttr());
4396  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4397  "indexing_maps"};
4398  unsigned arity =
4399  getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4400  unsigned numDims = getResultRank();
4401 
4402  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4403  ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4404  getContext()),
4405  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4406 
4407  if (!llvm::equal(getIndexingMaps(), indexingMaps))
4408  p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4409 
4410  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4411  elidedAttrs);
4412 }
4413 
4414 LogicalResult ElementwiseOp::verify() {
4415  // All necessary checks are done either by
4416  // - EnumAttr (e.g. unknown operation kind)
4417  // - verifyStructuredOpInterface (incorrect map, sizes).
4418  return success();
4419 }
4420 
4421 /// Implements the block region builder for the ElementwiseOp. This is called by
4422 /// 'fillStructuredOpRegion'.
4423 void ElementwiseOp::regionBuilder(
4424  ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4425  function_ref<InFlightDiagnostic()> emitError) {
4426  ElementwiseKind elemwiseKind;
4427  for (auto attr : attrs) {
4428  if (attr.getName() == b.getStringAttr("kind")) {
4429  auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4430  assert(kindAttr && "op kind attribute incorrectly set");
4431  elemwiseKind = kindAttr.getValue();
4432  break;
4433  }
4434  }
4435 
4436  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4437  auto arityGroup = groupAndKind.arityGroup;
4438  auto kind = groupAndKind.kind;
4439  if (emitError && block.getNumArguments() !=
4440  getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4441  emitError() << "Elementwise regionBuilder expects "
4442  << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4443  << block.getNumArguments();
4444  return;
4445  }
4446  assert(block.getNumArguments() ==
4447  getArityGroupAsUInt(arityGroup) + 1 /*output*/
4448  && "Elementwise regionBuilder number of block args mismatch");
4449 
4450  RegionBuilderHelper helper(b, block);
4451  SmallVector<Value> yields;
4452  Value result;
4453 
4454  if (arityGroup == ElementwiseArityGroup::Unary) {
4455  result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4456 
4457  } else if (arityGroup == ElementwiseArityGroup::Binary) {
4458  result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4459  block.getArgument(1));
4460 
4461  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4462  result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4463  block.getArgument(1), block.getArgument(2));
4464 
4465  } else {
4466  llvm_unreachable("found unhandled category in elemwise");
4467  }
4468 
4469  yields.push_back(result);
4470  helper.yieldOutputs(yields);
4471 }
4472 
4473 LogicalResult ElementwiseOp::fold(FoldAdaptor,
4474  SmallVectorImpl<OpFoldResult> &) {
4475  return memref::foldMemRefCast(*this);
4476 }
4477 
4478 void ElementwiseOp::getEffects(
4479  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4480  &effects) {
4481  if (hasPureTensorSemantics())
4482  return;
4483  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4484 }
4485 
4486 Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4487  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4488 }
4489 
4490 //===----------------------------------------------------------------------===//
4491 // PackOp/UnPackOp Common
4492 //===----------------------------------------------------------------------===//
4493 
4494 template <typename OpTy, typename>
4495 static SmallVector<int64_t>
4496 getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) {
4497  RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value)
4498  ? packOrUnPack.getDestType()
4499  : packOrUnPack.getSourceType();
4500  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4501  ? packOrUnPack.getSourceType()
4502  : packOrUnPack.getDestType();
4503  SmallVector<int64_t> result(
4504  packedType.getShape().take_front(unpackedType.getRank()));
4505  if (!packOrUnPack.getOuterDimsPerm().empty()) {
4506  applyPermutationToVector(
4507  result, invertPermutationVector(packOrUnPack.getOuterDimsPerm()));
4508  }
4509  return result;
4510 }
4515 
4516 // Given the (potentially) updated packed type, `newPackedTy`, generates an
4517 // updated mixed-tile-sizes attribute. A tile size is updated only
4518 // when:
4519 // * a dim from newPackedTy is static, and
4520 // * the corresponding size from mixedTiles is still dynamic.
4521 // Otherwise, the original tile size is preserved.
4522 // Note - packed-type-dim and mixed-tile-size should always match!
4523 static SmallVector<OpFoldResult>
4524 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4525  SmallVector<OpFoldResult> mixedTiles) {
4526  SmallVector<OpFoldResult> newMixedTileSizes;
4527  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4528  .getShape()
4529  .take_back(mixedTiles.size()),
4530  mixedTiles)) {
4531  int64_t shape = std::get<0>(it);
4532  if (shape == ShapedType::kDynamic) {
4533  newMixedTileSizes.push_back(std::get<1>(it));
4534  continue;
4535  }
4536 
4537  // If the current result dim is static, update the dynamic mixed-size
4538  // (provided the original value is dynamic).
4539  OpFoldResult tile = std::get<1>(it);
4540  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4541  // Already a constant
4542  newMixedTileSizes.push_back(tile);
4543  } else {
4544  assert(getConstantIntValue(tile).value() == shape &&
4545  "tile size and dim size don't match!");
4546  newMixedTileSizes.push_back(
4547  (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4548  }
4549  }
4550 
4551  return newMixedTileSizes;
4552 }
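// Worked example for getNewMixedTileSizes: with newPackedTy =
// tensor<?x?x16x32xf32> and mixedTiles = [%c16, 32], the trailing packed dims
// 16 and 32 are zipped against the tiles; the dynamic %c16 is replaced by the
// now-static 16 and the already-constant 32 is preserved, yielding [16, 32].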
4553 
4554 template <typename OpTy>
4555 static LogicalResult
4556 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
4557  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4558  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4559  "applies to only pack or unpack operations");
4560  int64_t destRank = op.getDestRank();
4561  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
4562  reifiedReturnShapes[0] =
4563  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
4564  return success();
4565 }
4566 
4567 template <typename OpTy>
4568 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
4569  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4570  "applies to only pack or unpack operations");
4571  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
4572  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
4573  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
4574  assert(tiles.size() == dimsToTile.size() &&
4575  "tiles must match indices of dimension to block");
4576  // Bind dimension `i` to its tile factor.
4577  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4578  dimAndTileMapping[dimsToTile[i]] = tiles[i];
4579  return dimAndTileMapping;
4580 }
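// For example: inner_dims_pos = [1, 0] with mixed tiles [8, %t] produces the
// mapping {1 -> 8, 0 -> %t}.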
4581 
4582 template <typename OpTy>
4583 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
4584  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4585  "applies to only pack or unpack operations");
4586  Builder builder(op);
4587  SmallVector<OpFoldResult> mixedInnerTiles;
4588  unsigned dynamicValIndex = 0;
4589  for (int64_t staticTile : op.getStaticInnerTiles()) {
4590  if (ShapedType::isStatic(staticTile))
4591  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
4592  else
4593  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4594  }
4595  return mixedInnerTiles;
4596 }
4597 
4598 template <typename OpTy>
4599 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
4600  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4601  "applies to only pack or unpack operations");
4602  SmallVector<Value> dynamicTiles;
4603  SmallVector<int64_t> staticTiles;
4604  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
4605  return staticTiles;
4606 }
4607 
4608 /// Returns true if `dimsPos` is invalid. It is invalid when:
4609 /// a) It contains duplicates.
4610 /// b) At least one dimension is out of bounds: a valid `dimPos` satisfies 0 <= `dimPos` < `rank`.
4611 /// c) The number of elements in `dimsPos` is greater than `rank`.
4612 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
4613  size_t rank) {
4614  size_t dimsPosSize = dimsPos.size();
4615  if (dimsPosSize > rank)
4616  return true;
4617  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
4618  if (dimsPosSize != uniqued.size())
4619  return true;
4620  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4621  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4622  });
4623 }
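// For example, with rank = 3: dimsPos = [0, 0] is invalid (duplicate),
// dimsPos = [0, 3] is invalid (3 is out of bounds) and dimsPos = [2, 0] is
// valid.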
4624 
4625 /// Returns true if every dimension of `sourceShape` is less than or equal to
4626 /// the corresponding dimension of `limitShape` (dynamic extents always pass).
4627 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
4628  ArrayRef<int64_t> limitShape) {
4629  assert(
4630  sourceShape.size() == limitShape.size() &&
4631  "expected source shape and limit shape to have the same rank");
4632  return llvm::all_of(
4633  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4634  int64_t sourceExtent = std::get<0>(it);
4635  int64_t limit = std::get<1>(it);
4636  return ShapedType::isDynamic(sourceExtent) ||
4637  ShapedType::isDynamic(limit) || sourceExtent <= limit;
4638  });
4639 }
4640 
4641 template <typename OpTy>
4642 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4643  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4644  "applies to only pack or unpack operations");
4645  Operation *op = packOrUnPack.getOperation();
4646 
4647  // Return true if we have a zero-value tile.
4648  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4649  return llvm::any_of(tiles, isZeroInteger);
4650  };
4651 
4652  // Verify tiles. Do not allow zero tiles.
4653  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
4654  if (hasZeros(mixedTiles))
4655  return op->emitError("invalid zero tile factor");
4656 
4657  // Verify inner_dims_pos and outer_dims_perm.
4658  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4659  ? packOrUnPack.getSourceType()
4660  : packOrUnPack.getDestType();
4661  size_t unpackedRank = unpackedType.getRank();
4662  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
4663  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4665  return op->emitError("invalid inner_dims_pos vector");
4666  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
4667  return op->emitError("invalid outer_dims_perm vector");
4668  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4669  return op->emitError("outer_dims_perm must be a permutation or empty");
4670 
4671  // Tiling factors must be less than or equal to the input rank for pack (or
4672  // output rank for unpack), and must match the number of `inner_dims_pos`.
4673  if (mixedTiles.size() > unpackedRank) {
4674  return op->emitError("tiling factors must be less than or equal to the "
4675  "input rank for pack or output rank for unpack");
4676  }
4677  if (mixedTiles.size() != innerDimsPos.size()) {
4678  return op->emitError(
4679  "tiling factors must equal the number of dimensions to tile");
4680  }
4681 
4682  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4683  ? packOrUnPack.getDestType()
4684  : packOrUnPack.getSourceType();
4685  size_t packedRank = packedType.getRank();
4686  // Require output rank to match input rank + number of blocking factors.
4687  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4688  if (expectedPackedRank != packedRank) {
4689  return op->emitError(
4690  "packed rank != (unpacked rank + num tiling factors), got ")
4691  << packedRank << " != " << expectedPackedRank;
4692  }
4693 
4694  // Verify result shape is greater than the minimum expected
4695  // by the pack operation, and that the output shape
4696  // represents full tiles.
4697  RankedTensorType expectedPackedType = PackOp::inferPackedType(
4698  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4699  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4700  return op->emitError("the shape of output is not large enough to hold the "
4701  "packed data. Expected at least ")
4702  << expectedPackedType << ", got " << packedType;
4703  }
4704  if (!llvm::all_of(
4705  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4706  mixedTiles),
4707  [](std::tuple<int64_t, OpFoldResult> it) {
4708  int64_t shape = std::get<0>(it);
4709  if (Attribute attr =
4710  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4711  IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4712  int64_t staticTileSize = intAttr.getValue().getSExtValue();
4713  return shape == staticTileSize;
4714  }
4715  return ShapedType::isDynamic(shape);
4716  })) {
4717  return op->emitError("mismatch in inner tile sizes specified and shaped of "
4718  "tiled dimension in the packed type");
4719  }
4720  return success();
4721 }
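// A pack that satisfies the above checks, for illustration (the source is
// padded up to full 8x32 tiles, so the outer dims are ceil(129/8) = 17 and
// ceil(47/32) = 2):
//
//   %pack = linalg.pack %src padding_value(%pad : f32)
//       inner_dims_pos = [0, 1] inner_tiles = [8, 32]
//       into %dest : tensor<129x47xf32> -> tensor<17x2x8x32xf32>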
4722 
4723 namespace {
4724 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4725 /// various permutations to the op.
4726 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4727 // these. These may or may not become true foldings / canonicalizations
4728 // depending on how aggressive we want to be in automatically folding
4729 // transposes.
4730 struct PackOrUnPackTransposeResult {
4731  SmallVector<int64_t> innerDimsPos;
4732  SmallVector<OpFoldResult> innerTiles;
4733  SmallVector<int64_t> outerDimsPerm;
4734 };
4735 } // namespace
4736 
4737 template <typename OpTy>
4738 static PackOrUnPackTransposeResult
4739 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4740  ArrayRef<int64_t> innerPermutation,
4741  ArrayRef<int64_t> outerPermutation) {
4742  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4743  "applies to only pack or unpack operations");
4744  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4745  "some permutation must be non-empty");
4746  PackOrUnPackTransposeResult metadata;
4747  metadata.innerDimsPos =
4748  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4749  metadata.innerTiles =
4750  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4751  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4752  ? packOrUnPackOp.getSourceRank()
4753  : packOrUnPackOp.getDestRank();
4754  metadata.outerDimsPerm =
4755  packOrUnPackOp.getOuterDimsPerm().empty()
4756  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4757  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4758  if (!innerPermutation.empty()) {
4759  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4760  isPermutationVector(innerPermutation) &&
4761  "invalid inner permutation");
4762  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4763  applyPermutationToVector(metadata.innerTiles, innerPermutation);
4764  }
4765  if (!outerPermutation.empty()) {
4766  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4767  isPermutationVector(outerPermutation) &&
4768  "invalid outer permutation");
4769  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4770  }
4771  return metadata;
4772 }
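// For example: an op with inner_dims_pos = [0, 1], inner_tiles = [8, 32] and
// innerPermutation = [1, 0] yields inner_dims_pos = [1, 0] and
// inner_tiles = [32, 8]; outerPermutation permutes outerDimsPerm analogously.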
4773 
4774 //===----------------------------------------------------------------------===//
4775 // PackOp
4776 //===----------------------------------------------------------------------===//
4777 
4778 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4779  setNameFn(getResult(), "pack");
4780 }
4781 
4782 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4783  Value dest, ArrayRef<int64_t> innerDimsPos,
4784  ArrayRef<OpFoldResult> innerTiles,
4785  std::optional<Value> paddingValue,
4786  ArrayRef<int64_t> outerDimsPerm) {
4787  assert(innerDimsPos.size() == innerTiles.size() &&
4788  "number of tile sizes specified must match the specified number of "
4789  "original dimensions to be tiled");
4790  SmallVector<int64_t> staticTileSizes;
4791  SmallVector<Value> dynamicTileSizes;
4792  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4793  build(builder, state, dest.getType(), source, dest,
4794  paddingValue ? *paddingValue : nullptr,
4795  outerDimsPerm.empty() ? nullptr
4796  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4797  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4798  builder.getDenseI64ArrayAttr(staticTileSizes));
4799 }
4800 
4801 LogicalResult
4802 PackOp::reifyResultShapes(OpBuilder &builder,
4803  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4804  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4805 }
4806 
4807 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4808  return getDimAndTileMappingImpl(*this);
4809 }
4810 
4811 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4812  return getMixedTilesImpl(*this);
4813 }
4814 
4815 SmallVector<int64_t> PackOp::getStaticTiles() {
4816  return getStaticTilesImpl(*this);
4817 }
4818 
4819 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4820  ShapedType inputType = getSourceType();
4821  int64_t inputRank = inputType.getRank();
4822  return getDestType().getShape().take_front(inputRank);
4823 }
4824 
4825 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4826  auto innerDimsPos = getInnerDimsPos();
4827  auto packedShape = getDestType().getShape();
4828  SmallVector<int64_t> res;
4829 
4830  for (auto index : innerDimsPos)
4831  res.push_back(packedShape[index]);
4832 
4833  return res;
4834 }
4835 
4836 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4837  ArrayRef<int64_t> innerDimsPos,
4838  ArrayRef<int64_t> outputShape,
4839  ArrayRef<int64_t> outerDimsPerm,
4840  ArrayRef<OpFoldResult> innerTiles) {
4841  SmallVector<int64_t> outputTileSizes(
4842  outputShape.take_front(inputShape.size()));
4843  if (!outerDimsPerm.empty()) {
4844  assert(outerDimsPerm.size() == outputTileSizes.size() &&
4845  "expected output and outer_dims_perm to have same size");
4846  applyPermutationToVector(outputTileSizes,
4847  invertPermutationVector(outerDimsPerm));
4848  }
4849  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4850  if (ShapedType::isDynamic(inputShape[pos]))
4851  continue;
4852  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4853 
4854  if (!constantTile) {
4855  if (ShapedType::isStatic(outputTileSizes[pos]) &&
4856  (inputShape[pos] % outputTileSizes[pos] != 0))
4857  return true;
4858  } else if (inputShape[pos] % (*constantTile) != 0) {
4859  return true;
4860  }
4861  }
4862  return false;
4863 }
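// For example: inputShape = [129, 47] with inner_tiles = [8, 32] on
// inner_dims_pos = [0, 1] requires a padding value (129 % 8 != 0), whereas
// inputShape = [128, 64] with the same tiles does not.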
4864 
4865 LogicalResult PackOp::verify() {
4866  if (failed(commonVerifierPackAndUnPackOp(*this)))
4867  return failure();
4868 
4869  // Verify padding value, and bail out if the tile does not divide the
4870  // dimension fully. In the case of dynamic tile factors or dimensions, having
4871  // a partial tile is undefined behavior.
4872  auto paddingValue = getPaddingValue();
4873  if (paddingValue &&
4874  paddingValue.getType() != getSourceType().getElementType()) {
4875  return emitOpError("expected padding_value has ")
4876  << getSourceType().getElementType()
4877  << " but got: " << paddingValue.getType();
4878  }
4879 
4880  if (!paddingValue &&
4881  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4882  getDestType().getShape(), getOuterDimsPerm(),
4883  getMixedTiles())) {
4884  return emitOpError(
4885  "invalid tile factor or output size provided. Only full tiles are "
4886  "supported when padding_value is not set");
4887  }
4888  return success();
4889 }
4890 
4891 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping
4892 /// all Values to kDynamic, even if they are arith.constant values.
4893 static SmallVector<int64_t>
4894 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4895  SmallVector<int64_t> result;
4896  for (auto o : ofrs) {
4897  // Have to do this first, as getConstantIntValue special-cases constants.
4898  if (llvm::dyn_cast_if_present<Value>(o))
4899  result.push_back(ShapedType::kDynamic);
4900  else
4901  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4902  }
4903  return result;
4904 }
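// For example: ofrs = {IntegerAttr 8, Value %v, Value of an arith.constant 4}
// maps to {8, kDynamic, kDynamic} -- even the constant-producing Value is
// treated as dynamic.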
4905 
4906 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4907 /// the packed type. Having a shared helper helps implement these two methods in
4908 /// a way that ensures that they agree on which dimensions are dynamic.
4909 static SmallVector<int64_t> getPackOpResultTypeShape(
4910  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4911  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4912  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4913  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4914  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4915  continue;
4916  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4917  resultShape[tiledDim.value()] = ShapedType::kDynamic;
4918  continue;
4919  }
4920  resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4921  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4922  }
4923 
4924  // Swap tile loops if outer_dims_perm is available.
4925  if (!outerDimsPerm.empty())
4926  applyPermutationToVector(resultShape, outerDimsPerm);
4927 
4928  // Append the inner tile dimensions.
4929  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4930  return resultShape;
4931 }
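// Worked example for getPackOpResultTypeShape: sourceShape = [129, 47],
// innerTileSizes = [8, 32], innerDimsPos = [0, 1], outerDimsPerm = [1, 0]:
//   tiled outer dims: [ceil(129/8), ceil(47/32)] = [17, 2]
//   after outer_dims_perm: [2, 17]
//   with the inner tiles appended: [2, 17, 8, 32]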
4932 
4933 SmallVector<OpFoldResult> PackOp::getResultShape(
4934  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4935  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4936  ArrayRef<int64_t> outerDimsPerm) {
4937  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4938 
4939  AffineExpr s0, s1;
4940  bindSymbols(builder.getContext(), s0, s1);
4941  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4942  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4943  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4944  builder, loc, ceilDivExpr,
4945  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4946  }
4947  if (!outerDimsPerm.empty())
4948  applyPermutationToVector(resultDims, outerDimsPerm);
4949  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4950 
4951  SmallVector<int64_t> resultTypeShape =
4952  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4953  asShapeWithAnyValueAsDynamic(innerTileSizes),
4954  innerDimsPos, outerDimsPerm);
4955 
4956  // Fix-up `resultDims` to ensure that they are Values if and only if the
4957  // result type shape says it's a dynamic dim. This is needed as callers may
4958  // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4959  // dynamic dims returned by that.
4960  for (unsigned i = 0; i < resultDims.size(); ++i) {
4961  if (ShapedType::isStatic(resultTypeShape[i]))
4962  continue;
4963  resultDims[i] =
4964  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4965  }
4966 
4967  return resultDims;
4968 }
4969 
4970 /// Get the expected packed type based on source type, tile factors, position of
4971 /// the inner tiles and permutation of the outer tiled loop.
4972 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4973  ArrayRef<int64_t> innerTileSizes,
4974  ArrayRef<int64_t> innerDimsPos,
4975  ArrayRef<int64_t> outerDimsPerm) {
4976  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4977  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4978  return RankedTensorType::get(resultShape, sourceType.getElementType());
4979 }
4980 
4981 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4982  ArrayRef<OpFoldResult> innerTileSizes,
4983  ArrayRef<int64_t> innerDimsPos,
4984  ArrayRef<int64_t> outerDimsPerm) {
4985  AffineExpr dim0, dim1;
4986  bindDims(b.getContext(), dim0, dim1);
4987  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4988  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4989  {v1, v2});
4990  };
4991 
4992  SmallVector<OpFoldResult> mixedSizes;
4993  for (auto [index, value] : llvm::enumerate(
4994  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4995  if (ShapedType::isDynamic(value))
4996  mixedSizes.push_back(
4997  tensor::DimOp::create(b, loc, source, index).getResult());
4998  else
4999  mixedSizes.push_back(b.getIndexAttr(value));
5000  }
5001  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
5002  int64_t dimPos = std::get<0>(it);
5003  OpFoldResult tileSize = std::get<1>(it);
5004  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
5005  }
5006  if (!outerDimsPerm.empty())
5007  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
5008 
5009  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
5010  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
5011  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5012 }
5013 
5014 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
5015  ArrayRef<int64_t> innerPermutation,
5016  ArrayRef<int64_t> outerPermutation) {
5017  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5018  *this, innerPermutation, outerPermutation);
5019  Value transposedDest =
5020  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
5021  metadata.innerDimsPos, metadata.outerDimsPerm);
5022  return PackOp::create(b, loc, getSource(), transposedDest,
5023  metadata.innerDimsPos, metadata.innerTiles,
5024  getPaddingValue(), metadata.outerDimsPerm);
5025 }
5026 
5027 /// Returns true if the tiles and the tiled dims are constant.
5028 template <typename OpTy>
5029 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
5030  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5031  "applies to only pack or unpack operations");
5032  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5033  ? op.getDestType()
5034  : op.getSourceType();
5035  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5036  for (auto [dimDest, tile] : llvm::zip(
5037  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
5038  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
5039  if (!constTileSize || ShapedType::isDynamic(dimDest))
5040  return false;
5041  }
5042  return true;
5043 }
5044 
5045 Speculation::Speculatability PackOp::getSpeculatability() {
5046  if (getPaddingValue())
5047  return Speculation::Speculatable;
5048 
5049  // The verifier already rejects operations if we can statically prove that
5050  // the tile sizes do not divide the dimension perfectly; thus, check only
5051  // that the tiles and the tiled inner dimensions are constant.
5052  if (!areTilesAndTiledDimsAllConstant(*this))
5053  return Speculation::NotSpeculatable;
5054 
5055  return Speculation::Speculatable;
5056 }
5057 
5058 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5059 // dimensions for pack and unpack.
5060 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5061  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5062  return false;
5063  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5064  return true;
5065  // Outer dims permutation is optional.
5066  // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5067  // identity permutation.
5068  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
5069  isIdentityPermutation(unPackOp.getOuterDimsPerm());
5070 }
5071 
5072 // Return true if pack and unpack have the same tiles.
5073 // Same SSA values or same integer constants.
5074 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5075  auto packTiles = packOp.getMixedTiles();
5076  auto unPackTiles = unPackOp.getMixedTiles();
5077  if (packTiles.size() != unPackTiles.size())
5078  return false;
5079  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5080  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
5081  return false;
5082  }
5083  return true;
5084 }
5085 
5086 /// Returns true if the pack op does not need a padding value.
5087 static bool paddingIsNotNeeded(PackOp op) {
5088  auto srcType = op.getSourceType();
5089  if (llvm::any_of(op.getInnerDimsPos(),
5090  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
5091  return false;
5092  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
5093  return false;
5094  return !PackOp::requirePaddingValue(
5095  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
5096  op.getOuterDimsPerm(), op.getMixedTiles());
5097 }
5098 
5099 /// Returns true if the `srcShape` or `destShape` is different from the one in
5100 /// `packOp` and populates each with the inferred static shape.
5101 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
5102  SmallVectorImpl<int64_t> &destShape) {
5103  bool changeNeeded = false;
5104  srcShape.assign(packOp.getSourceType().getShape().begin(),
5105  packOp.getSourceType().getShape().end());
5106  destShape.assign(packOp.getDestType().getShape().begin(),
5107  packOp.getDestType().getShape().end());
5108  llvm::SmallSetVector<int64_t, 4> innerDims;
5109  innerDims.insert_range(packOp.getInnerDimsPos());
5110  SmallVector<int64_t> inverseOuterDimsPerm;
5111  if (!packOp.getOuterDimsPerm().empty())
5112  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
5113  int srcRank = packOp.getSourceRank();
5114  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
5115  if (innerDims.contains(i))
5116  continue;
5117  int64_t srcPos = i;
5118  int64_t destPos = i;
5119  if (!inverseOuterDimsPerm.empty())
5120  destPos = inverseOuterDimsPerm[srcPos];
5121  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5122  ShapedType::isDynamic(destShape[destPos])) {
5123  continue;
5124  }
5125  int64_t size = srcShape[srcPos];
5126  if (ShapedType::isDynamic(size))
5127  size = destShape[destPos];
5128  srcShape[srcPos] = size;
5129  destShape[destPos] = size;
5130  changeNeeded = true;
5131  }
5132  return changeNeeded;
5133 }
5134 
5135 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
5136  // Fold pack(unpack(x)) to x.
5137  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5138  if (unPackOp.getSourceType() != packOp.getDestType())
5139  return failure();
5140  if (packOp.getPaddingValue() ||
5141  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5142  !haveSameTiles(packOp, unPackOp))
5143  return failure();
5144  rewriter.replaceOp(packOp, unPackOp.getSource());
5145  return success();
5146  }
5147 
5148  // Fold optional PaddingValue operand away if padding is not needed.
5149  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5150  rewriter.startOpModification(packOp);
5151  packOp.getPaddingValueMutable().clear();
5152  rewriter.finalizeOpModification(packOp);
5153  return success();
5154  }
5155 
5156  // Insert tensor.cast ops if static shape inference is available.
5157  SmallVector<int64_t> srcShape, destShape;
5158  if (inferStaticShape(packOp, srcShape, destShape)) {
5159  Location loc = packOp.getLoc();
5160  Value source = packOp.getSource();
5161  if (srcShape != packOp.getSourceType().getShape()) {
5162  auto newSrcType = packOp.getSourceType().clone(srcShape);
5163  source =
5164  tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource());
5165  }
5166  Value dest = packOp.getDest();
5167  RankedTensorType originalResultType = packOp.getDestType();
5168  bool needUpdateDestType = (destShape != originalResultType.getShape());
5169  if (needUpdateDestType) {
5170  auto newDestType = packOp.getDestType().clone(destShape);
5171  dest =
5172  tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest());
5173  }
5174  rewriter.modifyOpInPlace(packOp, [&] {
5175  packOp.getSourceMutable().assign(source);
5176  packOp.getDestMutable().assign(dest);
5177  packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5178  });
5179  // Insert a cast if needed
5180  if (needUpdateDestType) {
5181  rewriter.setInsertionPointAfter(packOp);
5182  auto castOp =
5183  tensor::CastOp::create(rewriter, loc, originalResultType, packOp);
5184  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
5185  }
5186  return success();
5187  }
5188 
5189  return failure();
5190 }
5191 
5192 template <typename PackOrUnpackOp>
5193 static bool isLikePadUnPad(PackOrUnpackOp packOp,
5194  RankedTensorType packedTensorType) {
5195  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5196  std::is_same<PackOrUnpackOp, UnPackOp>::value,
5197  "Function meant for pack/unpack");
5198  // This is a pad if packing only adds ones and we don't transpose dimensions.
5199 
5200  // Check that we are not transposing any dimensions.
5201  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5202  int64_t numPackedDims = innerDimsPos.size();
5203  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
5204  if (orderedDims != innerDimsPos) {
5205  // Dimensions don't happen in order.
5206  return false;
5207  }
5208 
5209  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5210  int64_t packedRank = packedTensorType.getRank();
5211  // At this point we know that we are taking numPackedDims outer
5212  // dimensions and pushing them all the way as the innermost dimensions.
5213  // What's left on the outermost dimensions is, in this order:
5214  // - the factor of the packed dimensions, then
5215  // - the untouched dimensions.
5216  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5217  // if all the dimensions that bubble outward are ones.
5218  // Therefore check that all the dimensions but the numPackedDims innermost
5219  // ones are ones.
5220  return llvm::all_of(
5221  llvm::seq<int64_t>(0, packedRank - numPackedDims),
5222  [&packedShape](int64_t i) { return packedShape[i] == 1; });
5223 }
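// For example: packing tensor<8x16xf32> with inner_dims_pos = [0, 1] and
// inner_tiles = [8, 16] into tensor<1x1x8x16xf32> is like a pad, since every
// outer dimension is 1 and no dimension is transposed.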
5224 
5225 bool PackOp::isLikePad() {
5226  auto packedTensorType =
5227  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5228  return isLikePadUnPad(*this, packedTensorType);
5229 }
5230 
5231 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5232  std::optional<Attribute> paddingValue;
5233  if (auto pad = adaptor.getPaddingValue())
5234  paddingValue = pad;
5235  if (OpFoldResult reshapedSource = reshapeConstantSource(
5236  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5237  getDestType(), paddingValue))
5238  return reshapedSource;
5239  return {};
5240 }
5241 
5242 /// Folds a tensor.cast op into a consuming PackOp op if the
5243 /// `tensor.cast` has source that is more static than the consuming op.
5244 ///
5245 /// Example:
5246 /// ```mlir
5247 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5248 /// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5249 /// ```
5250 ///
5251 /// folds into:
5252 ///
5253 /// ```mlir
5254 /// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5255 /// ```
5256 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5257  using OpRewritePattern<PackOp>::OpRewritePattern;
5258 
5259  LogicalResult matchAndRewrite(PackOp op,
5260  PatternRewriter &rewriter) const override {
5261  if (!tensor::hasFoldableTensorCastOperand(op))
5262  return failure();
5263 
5264  SmallVector<Type> newResultTypes(op->getResultTypes());
5265  SmallVector<Value> newOperands =
5266  tensor::getUpdatedOperandsAfterCastOp(rewriter, op);
5267 
5268  // Get the updated mixed-tile-sizes attribute.
5269  SmallVector<OpFoldResult> newMixedTileSizes =
5270  getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5271 
5272  // Clone op.
5273  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5274  // this point. However, in practice, we use them for things that we'd like
5275  // to preserve. Implement a better abstraction.
5276  PackOp newOp =
5277  PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1],
5278  op.getInnerDimsPos(), newMixedTileSizes,
5279  op.getPaddingValue(), op.getOuterDimsPerm());
5280  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5281 
5282  // Replace op.
5283  Value oldResult = op.getResult();
5284  Value newResult = newOp.getResult();
5285  Value replacement =
5286  (newResult.getType() != oldResult.getType())
5287  ? tensor::CastOp::create(rewriter, op->getLoc(),
5288  oldResult.getType(), newResult)
5289  : newResult;
5290 
5291  rewriter.replaceOp(op, {replacement});
5292 
5293  return success();
5294  }
5295 };
5296 
5297 //===----------------------------------------------------------------------===//
5298 // UnPackOp
5299 //===----------------------------------------------------------------------===//
5300 
5301 void UnPackOp::getAsmResultNames(
5302  function_ref<void(Value, StringRef)> setNameFn) {
5303  setNameFn(getResult(), "unpack");
5304 }
5305 
5306 LogicalResult
5307 UnPackOp::reifyResultShapes(OpBuilder &builder,
5308  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5309  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5310 }
5311 
5312 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5313  return getDimAndTileMappingImpl(*this);
5314 }
5315 
5316 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5317  return getMixedTilesImpl(*this);
5318 }
5319 
5320 SmallVector<int64_t> UnPackOp::getStaticTiles() {
5321  return getStaticTilesImpl(*this);
5322 }
5323 
5324 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5325  ShapedType destType = getDestType();
5326  int64_t destRank = destType.getRank();
5327  return getSourceType().getShape().take_front(destRank);
5328 }
5329 
5330 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5331  auto innerDimsPos = getInnerDimsPos();
5332  auto packedShape = getSourceType().getShape();
5333  SmallVector<int64_t> res;
5334 
5335  for (auto index : innerDimsPos)
5336  res.push_back(packedShape[index]);
5337 
5338  return res;
5339 }
5340 
5341 LogicalResult UnPackOp::verify() {
5342  return commonVerifierPackAndUnPackOp(*this);
5343 }
5344 
5345 Speculation::Speculatability UnPackOp::getSpeculatability() {
5346  // See PackOp::getSpeculatability.
5347  if (!areTilesAndTiledDimsAllConstant(*this))
5348  return Speculation::NotSpeculatable;
5349 
5350  return Speculation::Speculatable;
5351 }
5352 
5353 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5354  Value dest, ArrayRef<int64_t> innerDimsPos,
5355  ArrayRef<OpFoldResult> innerTiles,
5356  ArrayRef<int64_t> outerDimsPerm) {
5357  assert(innerDimsPos.size() == innerTiles.size() &&
5358  "number of tile sizes specified must match the specified number of "
5359  "original dimensions to be tiled");
5360  SmallVector<int64_t> staticTileSizes;
5361  SmallVector<Value> dynamicTileSizes;
5362  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5363  build(builder, state, dest.getType(), source, dest,
5364  outerDimsPerm.empty() ? nullptr
5365  : builder.getDenseI64ArrayAttr(outerDimsPerm),
5366  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5367  builder.getDenseI64ArrayAttr(staticTileSizes));
5368 }
5369 
5370 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5371  Value source,
5372  ArrayRef<OpFoldResult> innerTileSizes,
5373  ArrayRef<int64_t> innerDimsPos,
5374  ArrayRef<int64_t> outerDimsPerm) {
5375  AffineExpr sym0, sym1;
5376  bindSymbols(b.getContext(), sym0, sym1);
5377  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5378  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5379  };
5380 
5381  SmallVector<OpFoldResult> mixedSizes;
5382  auto srcType = llvm::cast<RankedTensorType>(source.getType());
5383  for (auto i :
5384  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5385  if (srcType.isDynamicDim(i))
5386  mixedSizes.push_back(
5387  tensor::DimOp::create(b, loc, source, i).getResult());
5388  else
5389  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5390  }
5391  if (!outerDimsPerm.empty()) {
5392  applyPermutationToVector<OpFoldResult>(
5393  mixedSizes, invertPermutationVector(outerDimsPerm));
5394  }
5395 
5396  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5397  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5398 
5399  auto elemType = srcType.getElementType();
5400  return tensor::EmptyOp::create(b, loc, mixedSizes, elemType);
5401 }
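// For example: a source of tensor<17x2x8x32xf32> with inner_tiles = [8, 32]
// on inner_dims_pos = [0, 1] produces an empty destination of
// tensor<136x64xf32> (17 * 8 = 136 and 2 * 32 = 64).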
5402 
5403 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5404  Value transposedSource,
5405  ArrayRef<int64_t> innerPermutation,
5406  ArrayRef<int64_t> outerPermutation) {
5407  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5408  *this, innerPermutation, outerPermutation);
5409  return UnPackOp::create(b, loc, transposedSource, getDest(),
5410  metadata.innerDimsPos, metadata.innerTiles,
5411  metadata.outerDimsPerm);
5412 }
5413 
5414 /// Returns true if the `srcShape` or `destShape` is different from the one in
5415 /// `op` and populates each with the inferred static shape.
5416 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5417  SmallVectorImpl<int64_t> &destShape) {
5418  bool changeNeeded = false;
5419  srcShape.assign(op.getSourceType().getShape().begin(),
5420  op.getSourceType().getShape().end());
5421  destShape.assign(op.getDestType().getShape().begin(),
5422  op.getDestType().getShape().end());
5423  llvm::SmallSetVector<int64_t, 4> innerDims;
5424  innerDims.insert_range(op.getInnerDimsPos());
5425  SmallVector<int64_t> inverseOuterDimsPerm;
5426  if (!op.getOuterDimsPerm().empty())
5427  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5428  int destRank = op.getDestRank();
5429  for (auto i : llvm::seq<int64_t>(0, destRank)) {
5430  if (innerDims.contains(i))
5431  continue;
5432  int64_t srcPos = i;
5433  int64_t destPos = i;
5434  if (!inverseOuterDimsPerm.empty())
5435  srcPos = inverseOuterDimsPerm[destPos];
5436  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5437  ShapedType::isDynamic(destShape[destPos])) {
5438  continue;
5439  }
5440  int64_t size = srcShape[srcPos];
5441  if (ShapedType::isDynamic(size))
5442  size = destShape[destPos];
5443  srcShape[srcPos] = size;
5444  destShape[destPos] = size;
5445  changeNeeded = true;
5446  }
5447  return changeNeeded;
5448 }
5449 
5450 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5451  PatternRewriter &rewriter) {
5452  /// unpack(pack(x)) -> x
5453  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5454  if (packOp.getSourceType() != unPackOp.getDestType())
5455  return failure();
5456  if (packOp.getPaddingValue() ||
5457  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5458  !haveSameTiles(packOp, unPackOp))
5459  return failure();
5460  rewriter.replaceOp(unPackOp, packOp.getSource());
5461  return success();
5462  }
5463  /// unpack(destinationStyleOp(x)) -> unpack(x)
5464  if (auto dstStyleOp =
5465  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5466  auto destValue = cast<OpResult>(unPackOp.getDest());
5467  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5468  rewriter.modifyOpInPlace(unPackOp,
5469  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5470  return success();
5471  }
5472  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5473  if (unPackOp->hasOneUse()) {
5474  auto extractSliceUser =
5475  dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5476  if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) {
5477  OpBuilder::InsertionGuard g(rewriter);
5478  rewriter.setInsertionPoint(unPackOp);
5479  auto newDest = tensor::ExtractSliceOp::create(
5480  rewriter, unPackOp->getLoc(), unPackOp.getDest(),
5481  extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5482  extractSliceUser.getMixedStrides());
5483  rewriter.modifyOpInPlace(unPackOp, [&]() {
5484  unPackOp.setDpsInitOperand(0, newDest);
5485  unPackOp.getResult().setType(newDest.getType());
5486  });
5487  rewriter.replaceOp(extractSliceUser, unPackOp);
5488  return success();
5489  }
5490  }
5491 
5492  // Insert tensor.cast ops if static shape inference is available.
5493  SmallVector<int64_t> srcShape, destShape;
5494  if (inferStaticShape(unPackOp, srcShape, destShape)) {
5495  Location loc = unPackOp.getLoc();
5496  Value source = unPackOp.getSource();
5497  if (srcShape != unPackOp.getSourceType().getShape()) {
5498  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5499  source = tensor::CastOp::create(rewriter, loc, newSrcType,
5500  unPackOp.getSource());
5501  }
5502  Value dest = unPackOp.getDest();
5503  if (destShape != unPackOp.getDestType().getShape()) {
5504  auto newDestType = unPackOp.getDestType().clone(destShape);
5505  dest = tensor::CastOp::create(rewriter, loc, newDestType,
5506  unPackOp.getDest());
5507  }
5508  Value newOp = UnPackOp::create(
5509  rewriter, loc, source, dest, unPackOp.getInnerDimsPos(),
5510  unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm());
5511  rewriter.replaceOpWithNewOp<tensor::CastOp>(
5512  unPackOp, unPackOp.getResult().getType(), newOp);
5513  return success();
5514  }
5515 
5516  return failure();
5517 }
5518 
5519 bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) {
5520  // Rank-reduced folding is not supported.
5521  if (sliceOp.getResultType().getRank() != this->getDestType().getRank())
5522  return false;
5523  if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) ||
5524  !areAllConstantIntValue(sliceOp.getMixedStrides(), 1))
5525  return false;
5526  RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType();
5527  SmallVector<int64_t> outerShapeWithoutTranspose =
5528  getPackedOuterShapeWithoutTransposition(*this);
5529  for (auto [pos, tileSize] :
5530  llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) {
5531  if (unpackedTypeAfterFold.isDynamicDim(pos))
5532  return false;
5533  if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos]))
5534  return false;
5535  if (ShapedType::isDynamic(tileSize))
5536  return false;
5537  int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize -
5538  unpackedTypeAfterFold.getDimSize(pos);
5539  if (paddingSize >= tileSize)
5540  return false;
5541  }
5542  return true;
5543 }
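// For example: unpacking tensor<2x8xf32> (tile size 8 on dim 0) yields
// tensor<16xf32>; folding an extract_slice down to tensor<9xf32> is allowed,
// since the 2 * 8 - 9 = 7 dropped elements fit within one tile of padding,
// whereas slicing to tensor<7xf32> (9 elements dropped, 9 >= 8) is not.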
5544 
5545 bool UnPackOp::isLikeUnPad() {
5546  RankedTensorType packedTensorType = getSourceType();
5547  return isLikePadUnPad(*this, packedTensorType);
5548 }
5549 
5550 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
5551  if (OpFoldResult reshapedSource = reshapeConstantSource(
5552  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5553  getResult().getType()))
5554  return reshapedSource;
5555  return {};
5556 }
5557 
5558 /// Folds a tensor.cast op into a consuming UnPackOp op if the
5559 /// `tensor.cast` has source that is more static than the consuming op.
5560 ///
5561 /// Example:
5562 /// ```mlir
5563 /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
5564 /// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
5565 /// ```
5566 ///
5567 /// folds into:
5568 ///
5569 /// ```mlir
5570 /// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
5571 /// ```
5572 struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
5573  using OpRewritePattern<UnPackOp>::OpRewritePattern;
5574 
5575  LogicalResult matchAndRewrite(UnPackOp op,
5576  PatternRewriter &rewriter) const override {
5577  if (!tensor::hasFoldableTensorCastOperand(op))
5578  return failure();
5579 
5580  SmallVector<Type> newResultTypes(op->getResultTypes());
5581  SmallVector<Value> newOperands =
5582  tensor::getUpdatedOperandsAfterCastOp(rewriter, op);
5583  Value sourceTensor = newOperands[0];
5584 
5585  // Get the updated mixed-tile-sizes attribute.
5586  SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
5587  rewriter, sourceTensor.getType(), op.getMixedTiles());
5588 
5589  // Clone op.
5590  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5591  // this point. However, in practice, we use them for things that we'd like
5592  // to preserve. Implement a better abstraction.
5593  UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor,
5594  newOperands[1], op.getInnerDimsPos(),
5595  newMixedTileSizes, op.getOuterDimsPerm());
5596  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5597 
5598  // Replace op.
5599  Value oldResult = op.getResult();
5600  Value newResult = newOp.getResult();
5601  Value replacement =
5602  (newResult.getType() != oldResult.getType())
5603  ? tensor::CastOp::create(rewriter, op->getLoc(),
5604  oldResult.getType(), newResult)
5605  : newResult;
5606 
5607  rewriter.replaceOp(op, {replacement});
5608 
5609  return success();
5610  }
5611 };
5612 
5613 //===----------------------------------------------------------------------===//
5614 // BatchReduceMatmulOp
5615 //===----------------------------------------------------------------------===//
5616 SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
5617  return SmallVector<utils::IteratorType>{
5618  utils::IteratorType::reduction, utils::IteratorType::parallel,
5619  utils::IteratorType::parallel, utils::IteratorType::reduction};
5620 }
5621 
5622 SmallVector<AffineMap>
5623 BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
5624  AffineExpr d0, d1, d2, d3;
5625  SmallVector<AffineMap> indexingMaps;
5626  bindDims(context, d0, d1, d2, d3);
5627  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
5628  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
5629  indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
5630  return indexingMaps;
5631 }
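// Spelled out, with (d0, d1, d2, d3) = (batch, m, n, k), the default maps are:
//   A: (batch, m, n, k) -> (batch, m, k)
//   B: (batch, m, n, k) -> (batch, k, n)
//   C: (batch, m, n, k) -> (m, n)
// i.e. the batch and k dimensions are reduced away in the output.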
5632 
5633 unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
5634 
5635 std::string BatchReduceMatmulOp::getLibraryCallName() {
5636  return generateLibraryCallName(getOperation());
5637 }
5638 
5639 /// Check if the op has broadcast and/or transpose semantics. Returns true if
5640 /// the user-defined indexing maps differ from the default maps.
5641 bool BatchReduceMatmulOp::hasUserDefinedMaps() {
5642  SmallVector<AffineMap, 3> defaultMaps =
5643  getDefaultIndexingMaps(this->getContext());
5644  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
5645  return defaultMaps != explicitMaps;
5646 }
5647 
5648 /// Returns true if the given bcastMap map is a valid broadcast map. A valid
5649 /// broadcast map must include K dimension.
5650 /// TODO: Strict inclusion of the K dimension in the broadcast map is not
5651 /// necessary for both input matrices simultaneously. We can relax this
5652 /// condition to require the K dimension in the map of only one input matrix
5653 /// and infer the K dimension for the other input matrix map from the one
5654 /// that already has it.
5655 bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
5656  bool isLHS) {
5657  assert(bcastMap.getNumResults() < 3 &&
5658  "Expected less than 3 result dim expr.");
5659  bool isValid = false;
5660  enum Indices { batchPos, mPos, nPos, kPos };
5661  if (bcastMap.getNumResults() == 1) {
5662  AffineExpr expr = bcastMap.getResult(0);
5663  isValid = expr.isFunctionOfDim(kPos);
5664  } else if (bcastMap.getNumResults() == 2) {
5665  AffineExpr expr0 = bcastMap.getResult(0);
5666  AffineExpr expr1 = bcastMap.getResult(1);
5667  isValid =
5668  isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
5669  expr0.isFunctionOfDim(mPos)) &&
5670  expr1.isFunctionOfDim(kPos))
5671  : ((expr0.isFunctionOfDim(batchPos) &&
5672  expr1.isFunctionOfDim(kPos)) ||
5673  (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
5674  }
5675  return isValid;
5676 }
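// For example, for the LHS: (batch, m, n, k) -> (k) and
// (batch, m, n, k) -> (m, k) are valid broadcast maps, while
// (batch, m, n, k) -> (m) is not, since the K dimension is missing.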
5677 
5678 void BatchReduceMatmulOp::regionBuilder(
5679  ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
5680  function_ref<InFlightDiagnostic()> emitError) {
5681  if (emitError && block.getNumArguments() != 3) {
5682  emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
5683  << block.getNumArguments();
5684  return;
5685  }
5686  assert(block.getNumArguments() == 3 &&
5687  "BatchReduceMatmulOp regionBuilder expects 3 args");
5688  RegionBuilderHelper helper(b, block);
5689  SmallVector<Value> yields;
5690 
5691  auto toType = block.getArgument(2).getType();
5692  Value castValA =
5693  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
5694  Value castValB =
5695  helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
5696  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
5697  Value addVal =
5698  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
5699  yields.push_back(addVal);
5700  helper.yieldOutputs(yields);
5701 }
5702 
5703 ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
5704  OperationState &result) {
5705  SmallVector<Attribute, 3> indexingMapsAttr;
5706  Attribute mapAttr;
5707  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
5708  if (parser.parseEqual())
5709  return failure();
5710  if (parser.parseLSquare())
5711  return failure();
5712 
5713  do {
5714  if (parser.parseAttribute(mapAttr))
5715  return failure();
5716  if (!isa<AffineMapAttr>(mapAttr)) {
5717  return parser.emitError(parser.getCurrentLocation(),
5718  "expected affine map attribute");
5719  }
5720  indexingMapsAttr.push_back(mapAttr);
5721 
5722  if (parser.parseOptionalComma())
5723  break;
5724  } while (true);
5725 
5726  if (parser.parseRSquare())
5727  return failure();
5728  }
5729  // Initialize indexingMaps, if not supplied explicitly.
5730  if (indexingMapsAttr.empty()) {
5731  indexingMapsAttr = llvm::map_to_vector(
5732  BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
5733  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5734  }
5735  result.addAttribute("indexing_maps",
5736  parser.getBuilder().getArrayAttr(indexingMapsAttr));
5737  return ::parseNamedStructuredOp(parser, result,
5738  BatchReduceMatmulOp::getNumRegionArgs(),
5739  BatchReduceMatmulOp::getRegionBuilder());
5740 }
5741 
5742 void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
5743  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
5744  BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
5745  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5746 
5747  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
5748  p << " indexing_maps = [";
5749  llvm::interleaveComma(getIndexingMaps(), p,
5750  [&](Attribute attr) { p.printAttribute(attr); });
5751  p << "]";
5752  }
5753 
5754  SmallVector<StringRef, 3> elidedAttrs = {
5755  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
5756  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
5757  elidedAttrs);
5758 }
5759 
5760 /// Verify the user defined indexing maps.
5761 LogicalResult BatchReduceMatmulOp::verify() {
5762  // Verification of pure batch_reduce_matmul is handled by
5763  // verifyStructuredOpInterface().
5764  if (!hasUserDefinedMaps())
5765  return success();
5766 
5767  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
5768  if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
5769  return failure();
5770  }
5771  return success();
5772 }
5773 LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
5774  SmallVectorImpl<OpFoldResult> &) {
5775  return memref::foldMemRefCast(*this);
5776 }
5777 void BatchReduceMatmulOp::getEffects(
5778  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5779  &effects) {
5780  if (hasPureTensorSemantics())
5781  return;
5782  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
5783 }
5784 
5785 Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
5786  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
5787 }
5788 
5789 } // namespace linalg
5790 } // namespace mlir
5791 
5792 //===----------------------------------------------------------------------===//
5793 // LinalgDialect
5794 //===----------------------------------------------------------------------===//
5795 
5796 void LinalgDialect::getCanonicalizationPatterns(
5797  RewritePatternSet &results) const {
5798  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
5799  FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
5800 }
5801 
5802 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
5803  Attribute value, Type type,
5804  Location loc) {
5805  return arith::ConstantOp::materialize(builder, value, type, loc);
5806 }
Definition: LinalgOps.cpp:2324
SmallVector< int64_t > innerDimsPos
Definition: LinalgOps.cpp:4731
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:359
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:1068
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:352
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:56
UnaryFn unaryFn
Definition: LinalgOps.cpp:4269
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1499
static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:221
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1604
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:390
union mlir::linalg::@1224::ArityGroupAndKind::Kind kind
static LogicalResult verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
Definition: LinalgOps.cpp:3666
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
Definition: LinalgOps.cpp:397
BinaryFn binaryFn
Definition: LinalgOps.cpp:4270
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:241
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder, SMLoc loc)
Definition: LinalgOps.cpp:327
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static LogicalResult getResultTilePosition(RewriterBase &rewriter, ReductionTilingStrategy reductionStrategy, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ReductionTilingStrategy reductionStrategy, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, ValueRange ivs, ArrayRef< OpFoldResult > numThreads, ArrayRef< OpFoldResult > tileSizes, const SetVector< unsigned > &reductionDims)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
Base type for affine expression.
Definition: AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:314
AffineExpr ceilDiv(uint64_t v) const
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:330
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
unsigned getNumDims() const
Definition: AffineMap.cpp:390
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:403
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:31
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:158
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:382
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:359
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:24
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:261
IndexType getIndexType()
Definition: Builders.cpp:50
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:313
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:750
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition: Builders.h:621
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:410
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
This class represents an operand of an operation.
Definition: Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:226
This is a value defined by a result of an operation.
Definition: Value.h:447
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:459
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:413
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_iterator result_end()
Definition: Operation.h:414
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:769
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:831
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:686
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:614
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen...
Definition: PatternMatch.h:598
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Block * getParentBlock()
Return the Block in which this Value is defined.
Definition: Value.cpp:46
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:197
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1278
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1331
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2768
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: LinalgOps.cpp:4556
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< PackOp >(PackOp)
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the one in packOp and populates each with...
Definition: LinalgOps.cpp:5101
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: LinalgOps.cpp:5193
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Value's to kDynamic,...
Definition: LinalgOps.cpp:4894
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: LinalgOps.cpp:4612
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if the dimension of sourceShape is smaller than the dimension of the limitShape.
Definition: LinalgOps.cpp:4627
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4599
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: LinalgOps.cpp:4909
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2423
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
Definition: LinalgOps.cpp:4280
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: LinalgOps.cpp:4739
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:103
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2464
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
Definition: LinalgOps.cpp:5087
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2403
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2414
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: LinalgOps.cpp:4568
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
Definition: LinalgOps.cpp:4524
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:5074
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:94
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:5060
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: LinalgOps.cpp:5029
std::pair< int64_t, int64_t > getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr)
Converts the given WinogradConv2DFmr enumeration value to a pair of m and r parameters.
Definition: LinalgOps.cpp:3706
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4583
template SmallVector< int64_t > getPackedOuterShapeWithoutTransposition< UnPackOp >(UnPackOp)
std::optional< WinogradConv2DFmr > getWinogradConv2DFmr(int64_t m, int64_t r)
Converts the given m and r parameters to a WinogradConv2DFmr enumeration value.
Definition: LinalgOps.cpp:3696
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: LinalgOps.cpp:4642
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
Definition: LinalgOps.cpp:3808
SmallVector< int64_t > getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack)
Returns the outer shape in the packed domain before applying the transposition.
Definition: LinalgOps.cpp:4496
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer,...
Definition: TensorOps.cpp:360
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:353
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp (i.e.
Definition: TensorOps.cpp:369
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:238
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool areAllConstantIntValue(ArrayRef< OpFoldResult > ofrs, int64_t value)
Return true if all of ofrs are constant integers equal to value.
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:311
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:784
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:325
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:111
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition: Utils.cpp:1282
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold transpose with transpose.
Definition: LinalgOps.cpp:2085
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2088
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
Definition: LinalgOps.cpp:2110
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2113
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:330
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp op if the tensor.cast has source that is more static t...
Definition: LinalgOps.cpp:5256
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5259
Folds a tensor.cast op into a consuming UnPackOp op if the tensor.cast has source that is more static...
Definition: LinalgOps.cpp:5572
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5575