MLIR 23.0.0git
TosaOps.cpp
Go to the documentation of this file.
1//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// This file implements the TOSA Specification:
11// https://www.mlplatform.org/tosa/tosa_spec.html
12//
13//===----------------------------------------------------------------------===//
14
25#include "mlir/IR/Matchers.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/SmallVectorExtras.h"
31#include "llvm/ADT/TypeSwitch.h"
32
33#include <numeric>
34#include <type_traits>
35
36using namespace mlir;
37using namespace mlir::tosa;
38
39#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
41
42//===----------------------------------------------------------------------===//
43// Tosa dialect interface includes.
44//===----------------------------------------------------------------------===//
45
46#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
47#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
48#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
49#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"
50
51namespace {
52#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"
53
54//===----------------------------------------------------------------------===//
55// Dialect Function Inliner Interface.
56//===----------------------------------------------------------------------===//
57struct TosaInlinerInterface : public DialectInlinerInterface {
58 using DialectInlinerInterface::DialectInlinerInterface;
59
60 //===--------------------------------------------------------------------===//
61 // Analysis Hooks.
62 //===--------------------------------------------------------------------===//
63
64 /// All operations can be inlined by default.
65 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
66 IRMapping &map) const final {
67 return true;
68 }
69
70 /// All regions with If and While parent operators can be inlined.
71 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
72 IRMapping &map) const final {
73 return (isa<tosa::IfOp>(dest->getParentOp()) ||
74 isa<tosa::WhileOp>(dest->getParentOp()));
75 }
76};
77
/// This class implements the bytecode interface for the Tosa dialect.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  /// Decode an attribute through the generated dialect bytecode reader.
  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  /// Encode an attribute through the generated dialect bytecode writer.
  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  /// Decode a type through the generated dialect bytecode reader.
  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  /// Encode a type through the generated dialect bytecode writer.
  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  /// No version payload is written while versioning is unsupported.
  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  /// Versioned TOSA bytecode is not supported yet; encountering a version
  /// record is reported as an error.
  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  /// Nothing to upgrade while versioning is unsupported.
  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};
123
124} // namespace
125
126//===----------------------------------------------------------------------===//
127// TOSA control flow support.
128//===----------------------------------------------------------------------===//
129
130/// Returns the while loop body.
131SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
132 return {&getBodyGraph()};
133}
134
135//===----------------------------------------------------------------------===//
136// TOSA variable operator support.
137//===----------------------------------------------------------------------===//
138
140 return map_to_vector(shape, [](int64_t dim) {
141 return dim == -1 ? ShapedType::kDynamic : dim;
142 });
143}
144
145// returns type of variable op
146RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
147 Type elementType = variableOp.getType();
148 DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
149 auto shape = convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
150 return RankedTensorType::get(shape, elementType);
151}
152
153//===----------------------------------------------------------------------===//
154// Tosa dialect initialization.
155//===----------------------------------------------------------------------===//
156
void TosaDialect::initialize() {
  // Register the ODS-generated types, operations, and attributes.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  // Promise (rather than eagerly register) the sharding interface models for
  // these ops; the implementations are attached elsewhere on demand.
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}
180
181Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
182 Type type, Location loc) {
183 // Tosa dialect constants only support ElementsAttr unlike standard dialect
184 // constant which supports all attributes.
185 if (llvm::isa<shapeType>(type) && llvm::isa<DenseIntElementsAttr>(value)) {
186 return tosa::ConstShapeOp::create(builder, loc, type,
187 llvm::cast<DenseIntElementsAttr>(value));
188 }
189 if (llvm::isa<ElementsAttr>(value))
190 return tosa::ConstOp::create(builder, loc, type,
191 llvm::cast<ElementsAttr>(value));
192 return nullptr;
193}
194
195//===----------------------------------------------------------------------===//
196// Parsers and printers
197//===----------------------------------------------------------------------===//
198
199namespace {
200
201ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
202 DenseElementsAttr &varShapeAttr,
203 TypeAttr &typeAttr) {
204 if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
205 if (!shapedType.hasRank())
206 return parser.emitError(parser.getCurrentLocation())
207 << "expected ranked type";
208
209 auto elementType = shapedType.getElementType();
210 typeAttr = TypeAttr::get(elementType);
211 ArrayRef<int64_t> shape = shapedType.getShape();
212 Builder builder(parser.getContext());
213 varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
214 return success();
215 }
216 return parser.emitError(parser.getCurrentLocation())
217 << "expected shaped type";
218}
219
220} // namespace
221
222// parses the optional initial value or type for a tosa variable
223// with initial value:
224// tosa.variable @name = dense<0.0> : tensor<1x8xf32>
225//
226// without initial value:
227// tosa.variable @name : tensor<1x8xf32>
229 OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
230 Attribute &initialValueAttr) {
231 if (succeeded(parser.parseOptionalEqual())) {
232 if (failed(parser.parseAttribute(initialValueAttr))) {
233 return parser.emitError(parser.getCurrentLocation())
234 << "expected attribute";
235 }
236 if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
237 return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
238 typeAttr);
239 }
240 return parser.emitError(parser.getCurrentLocation())
241 << "expected Typed attr";
242 }
243
244 initialValueAttr = nullptr;
245 Type parsedType;
246 if (failed(parser.parseColonType(parsedType))) {
247 return parser.emitError(parser.getCurrentLocation())
248 << "expected type after colon";
249 }
250 return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
251}
252
254 OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
255 TypeAttr typeAttr, Attribute initialValueAttr) {
256 bool needsSpace = false;
257 if (!dyn_cast_or_null<TypedAttr>(initialValueAttr)) {
258 auto shape =
259 convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
260 Type elementType = typeAttr.getValue();
261 RankedTensorType tensorType =
262 RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
263 auto tensorTypeAttr = TypeAttr::get(tensorType);
264 p << ": ";
265 p.printAttribute(tensorTypeAttr);
266 needsSpace = true; // subsequent attr value needs a space separator
267 }
268 if (initialValueAttr) {
269 if (needsSpace)
270 p << ' ';
271 p << "= ";
272 p.printAttribute(initialValueAttr);
273 }
274}
275
276namespace {
277
// Parses one `name = value` entry of an attribute dictionary, with special
// handling for tosa enum attributes: for the attribute name associated with
// EnumType, a *bare* enum keyword (no `#tosa<...>` wrapper) is accepted and
// converted to the corresponding enum attribute. Anything else falls through
// to normal attribute parsing.
template <typename EnumType>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
                                           NamedAttrList &outAttrs) {
  llvm::StringRef name;
  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
    return failure();

  // special handling: rounding_mode accepts a *bare* RoundingMode enum
  // keyword.
  llvm::StringRef kw;
  if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) {
    if (name == "rounding_mode" &&
        succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeRoundingMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid rounding_mode value: " << kw;
      auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // special handling: mode accepts a *bare* ResizeMode enum keyword.
  if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) {
    if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeResizeMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid resize mode value: " << kw;
      auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // special handling: nan_mode accepts a *bare* NanPropagationMode enum
  // keyword.
  if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) {
    if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeNanPropagationMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid nan_mode value: " << kw;
      auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // special handling: block_size accepts a *bare* BlockSizeMode enum
  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
    if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeBlockSize(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid block_size value: " << kw;
      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // Default path: parse any normal attribute literal, including fully qualified
  // enum keyword
  Attribute attr;
  return parser.parseAttribute(attr, name, outAttrs);
}
345
346template <typename EnumType>
347ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
348 // parse operands
350 if (parser.parseCommaSeparatedList(
351 [&]() { return parser.parseOperand(operands.emplace_back()); }))
352 return failure();
353
354 // Parse { attr-dict } with special handling for enum bare token
355 NamedAttrList attrs;
356 if (succeeded(parser.parseOptionalLBrace()) &&
357 failed(parser.parseOptionalRBrace())) {
358 do {
359 if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs))
360 return failure();
361 } while (succeeded(parser.parseOptionalComma()));
362 if (parser.parseRBrace())
363 return failure();
364 }
365
366 FunctionType fnTy;
367 if (parser.parseColonType(fnTy))
368 return failure();
369
370 // Resolve operands and types
371 if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
372 parser.getCurrentLocation(),
373 result.operands)))
374 return failure();
375
376 result.addTypes(fnTy.getResults());
377 result.addAttributes(attrs);
378
379 return success();
380}
381
382void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
383 parser << namedAttr.getName().strref() << " = ";
384 auto attr = namedAttr.getValue();
385 if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) {
386 parser << roundingModeAttr.getValue();
387 } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) {
388 parser << resizeModeAttr.getValue();
389 } else if (auto nanPropagationModeAttr =
390 dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
391 parser << nanPropagationModeAttr.getValue();
392 } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
393 parser << blockSizeAttr.getValue();
394 } else {
395 parser.printAttribute(attr);
396 }
397}
398
399// print with special handling for default valued NanPropagationMode attribute
400void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
401 parser << " ";
402 parser.printOperands(op->getOperands());
403
404 NamedAttrList toPrint(op->getAttrs());
405 // remove default NanPropagate attribute
406 const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
407 for (auto attr : op->getAttrs()) {
408 if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
409 if (nanAttr.getValue() == kDefaultNanValue) {
410 // elide from toPrint
411 toPrint.erase(attr.getName());
412 break;
413 }
414 }
415 }
416
417 if (!toPrint.empty()) {
418 parser << " {";
419 llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
420 printNamedAttr(parser, namedAttr);
421 });
422 parser << "}";
423 }
424
425 parser << " : ";
426 parser.printFunctionalType(op);
427}
428
429// print with special handling for enums: RoundingMode, ResizeMode
430void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
431 parser << " ";
432 parser.printOperands(op->getOperands());
433
434 if (!op->getAttrs().empty()) {
435 parser << " {";
436 llvm::interleaveComma(op->getAttrs(), parser,
437 [&](const NamedAttribute namedAttr) {
438 printNamedAttr(parser, namedAttr);
439 });
440 parser << "}";
441 }
442
443 parser << " : ";
444 parser.printFunctionalType(op);
445}
446
447} // namespace
448
// Custom assembly hooks: the ops below accept bare enum keywords
// (rounding_mode / mode / nan_mode) in their attribute dictionaries, so they
// delegate to the shared enum-aware parser and printers.
ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
}

void RescaleOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::RoundingMode>(parser, result);
}

void ApplyScaleOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::ResizeMode>(parser, result);
}

void ResizeOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

// Ops with a nan_mode attribute use the printer that elides the default
// PROPAGATE value.
void ArgMaxOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MaxPool2dOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
488
489ParseResult MaxPool2dAdaptiveOp::parse(OpAsmParser &parser,
491 return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
492}
493
// The ops below all carry a nan_mode attribute: parsing accepts a bare
// NanPropagationMode keyword, and printing elides the default PROPAGATE
// value.
void MaxPool2dAdaptiveOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ClampOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MaximumOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void MinimumOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ReduceMaxOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result);
}

void ReduceMinOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
537
538ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
540 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
541}
542
// Prints block_size as a bare enum keyword via the shared enum-aware printer.
void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
546
547ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
549 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
550}
551
// Prints block_size as a bare enum keyword via the shared enum-aware printer.
void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
555
556ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
558 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
559}
560
// Prints block_size as a bare enum keyword via the shared enum-aware printer.
void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
564
565ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
567 return parseWithEnumHandling<tosa::BlockSize>(parser, result);
568}
569
// Prints block_size as a bare enum keyword via the shared enum-aware printer.
void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}
573
574//===----------------------------------------------------------------------===//
575// Tosa utilities.
576//===----------------------------------------------------------------------===//
577
578static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
579 if (lhs % rhs != 0)
580 return std::nullopt;
581 return lhs / rhs;
582}
583
585 auto srcType = getElementTypeOrSelf(type);
586 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
587 srcType = getStorageElementTypeFromQuantized(quantType);
588 return srcType;
589}
590
594
595static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
596 Value valZp, StringRef name) {
598 Type eZpType = getStorageElementTypeOrSelf(valZp.getType());
599
600 bool bothInts =
601 mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
602 bool sameBitWidth =
603 (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());
604
605 if (!bothInts || !sameBitWidth) {
606 return op->emitOpError()
607 << "expected " << name << " and " << name
608 << "_zp to both be integer of the same bitwidth, but got " << eType
609 << " vs. " << eZpType;
610 }
611 return success();
612}
613
614// Create a pad-const const tensor with value of `val` of required data-type
616 Value src, int32_t val) {
617 const auto srcType = getElementTypeOrSelf(src);
618 const auto srcElemType = getStorageElementTypeOrSelf(src);
619 const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
620 const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
621 const auto padConstAttr{
622 llvm::isa<FloatType>(srcElemType)
623 ? DenseElementsAttr::get(padConstEType,
624 builder.getFloatAttr(srcElemType, val))
625 : DenseElementsAttr::get(padConstEType,
626 builder.getIntegerAttr(srcElemType, val))};
627 return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
628}
629
631 if (dyn_cast<tosa::mxint8Type>(type))
632 return 8;
633 return type.getIntOrFloatBitWidth();
634}
635
636// Update dim size if current dim is dynamic, otherwise raise an error if sizes
637// do not match
638LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
639 const int64_t newDim,
640 const StringRef operandName,
641 const StringRef dimName) {
642 if (ShapedType::isDynamic(currDim)) {
643 currDim = newDim;
644 return success();
645 } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
646 return op->emitOpError("expected ")
647 << dimName << " of " << operandName << " to match size " << currDim
648 << ", got " << newDim;
649 }
650 return success();
651}
652
655 auto printDim = [&](int64_t dim) {
656 if (ShapedType::isDynamic(dim))
657 diag << "?";
658 else
659 diag << dim;
660 };
661
662 llvm::interleaveComma(shape, diag, printDim);
663}
664
666 Operation *op, const int64_t inputSize, const int64_t kernelSize,
667 const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
668 const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
669 const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
670 const llvm::StringRef padAfterName) {
671 if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic)
672 return success();
673
674 // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
675
676 const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
677 inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
678 stride);
679 if (!calculatedOutSizeMinusOne.has_value())
680 return op->emitOpError("expected input_")
681 << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
682 << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
683 << dimAxis << " to be wholly divisible by stride_" << dimAxis
684 << ", got (" << inputSize << " - 1 + " << padBefore << " + "
685 << padAfter << " - (" << kernelSize << " - 1) * " << dilation
686 << ") / " << stride;
687
688 const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
689 if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
690 return op->emitOpError("calculated output ")
691 << dimName << " did not match expected: "
692 << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
693
694 return success();
695}
696
697//===----------------------------------------------------------------------===//
698// TOSA Operator Verifiers.
699//===----------------------------------------------------------------------===//
700
// Shared verifier for convolution-style ops: checks element-type consistency
// between input, weight, bias, result, and the zero-point operands.
template <typename T>
static LogicalResult verifyConvOp(T op) {
  // NOTE(review): these dyn_cast results are dereferenced without a null
  // check — presumably ODS constrains the operands to tensor types; confirm
  // before relying on this with non-tensor operands.
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  // Quantized element types are compared through their storage types.
  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
    inputEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
    weightEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
    biasEType = getStorageElementTypeFromQuantized(quantType);

  if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
    resultEType = getStorageElementTypeFromQuantized(quantType);

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    // for now, only enforce bias element type == result element type for
    // float types.
    op.emitOpError(
        "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;
    return failure();
  }

  // When either side is an f8 type, input and weight must match exactly.
  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      op.emitOpError(
          "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;
      return failure();
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  // Either both must be float or both non-float.
  if (inputIsFloat != weightIsFloat) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // Each zero point must share the (storage) element type of its tensor.
  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  // Validate the zero-point values when they can be retrieved.
  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();

  return success();
}
781
// Verifies that a tosa.const's values attribute is compatible with its result
// type: element types must match, except that a quantized result may be
// backed by an attribute holding the quantized type's storage element type.
LogicalResult tosa::ConstOp::verify() {

  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  // Quantized result: accept an attribute of the storage element type.
  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
      return success();
  }

  // Otherwise require an exact element-type match.
  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }

  return success();
}
805
806template <typename T>
807static LogicalResult verifyConvOpModes(T op) {
808 auto inputEType =
809 llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
810
811 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
812 inputEType = getStorageElementTypeFromQuantized(quantType);
813
814 auto accType = op.getAccType();
815 if (inputEType.isInteger(8) && !accType.isInteger(32))
816 return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
817 << accType;
818
819 if (inputEType.isInteger(16) && !accType.isInteger(48))
820 return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
821 << accType;
822
823 if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
824 !(accType.isF16() || accType.isF32()))
825 return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
826 << accType;
827
828 if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
829 return op.emitOpError(
830 "accumulator type for f16 tensor is not f16/f32, got ")
831 << accType;
832
833 if (inputEType.isBF16() && !accType.isF32())
834 return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
835 << accType;
836
837 if (inputEType.isF32() && !accType.isF32())
838 return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
839 << accType;
840
841 auto resultEType =
842 llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
843
844 if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
845 resultEType = getStorageElementTypeFromQuantized(quantType);
846
847 return success();
848}
849
850//===----------------------------------------------------------------------===//
851// ERROR_IF functions.
852// ERROR_IF is a predicate that must set an error if the condition holds.
853//===----------------------------------------------------------------------===//
854
// Implements the spec ERROR_IF checks shared by Conv2D, Conv3D, and
// DepthwiseConv2D: pad/stride/dilation value ranges, the output-size
// equation per spatial dimension, and the bias-channel count.
template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {
  llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  llvm::ArrayRef<int64_t> dilations = op.getDilation();
  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
  if (!outputType)
    // Skip following checks if output is not ranked
    return success();

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  // Spatial-dimension checks need ranked input and weight; the dimension
  // indices differ per op because of each op's operand layout.
  if (inputType && weightType) {
    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyConvOutputSize(
              op, inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
  if (!biasType)
    // Skip following checks if bias is not ranked
    return success();

  // Bias must either broadcast (size 1) or match the output channel count.
  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    // Skip following checks if biasChannels or outputChannels is dynamic dim
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
957
958// Verify whether same type and shape of the given two types.
959static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
960 StringRef name1, Type type2,
961 StringRef name2) {
962 auto shapeType1 = dyn_cast<ShapedType>(type1);
963 auto shapeType2 = dyn_cast<ShapedType>(type2);
964 if (!shapeType1 || !shapeType2)
965 return failure();
966
967 auto elemType1 = shapeType1.getElementType();
968 auto elemType2 = shapeType2.getElementType();
969 if (elemType1 != elemType2)
970 return op->emitOpError()
971 << "require same element type for " << name1 << " (" << elemType1
972 << ") and " << name2 << " (" << elemType2 << ")";
973
974 if (failed(verifyCompatibleShape(type1, type2)))
975 return op->emitOpError()
976 << "require same shapes for " << name1 << " (" << type1 << ") and "
977 << name2 << " (" << type2 << ")";
978
979 return success();
980}
981
982// Verify whether same length, type, and shape of the given two tensor lists.
983static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
984 StringRef name1,
985 ValueRange list2,
986 StringRef name2) {
987 if (list1.size() != list2.size())
988 return op->emitOpError()
989 << "require same number of values in " << name1 << " ("
990 << list1.size() << ") and " << name2 << " (" << list2.size() << ")";
991
992 for (auto [type1, type2] :
993 llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
994 if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
995 return failure();
996 }
997
998 return success();
999}
1000
1001static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
1002 ShapeAdaptor shapeAdaptor(type);
1003 if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
1004 return success();
1005
1006 return shapeAdaptor.getNumElements() == 1 ? success() : failure();
1007}
1008
1009template <typename T>
1010static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
1011 Operation *symTableOp =
1012 op->template getParentWithTrait<OpTrait::SymbolTable>();
1013 if (!symTableOp)
1014 // If the operation is not the scope of a symbol table, we cannot
1015 // verify it against it's declaration.
1016 return success();
1017
1018 SymbolTable symTable(symTableOp);
1019 const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());
1020
1021 // Verify prior declaration
1022 if (!varOp)
1023 return op->emitOpError("'")
1024 << op.getName() << "' has not been declared by 'tosa.variable'";
1025
1026 // Verify type and shape
1027 auto variableType = getVariableType(varOp);
1028 if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
1029 "the input tensor")
1030 .failed())
1031 return failure();
1032 return success();
1033}
1034
1035// verify that inType and outType have same element types
1036template <typename T>
1037static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
1038 StringRef aName = "input",
1039 StringRef bName = "output") {
1040 auto aTType = llvm::dyn_cast<TensorType>(aType);
1041 auto bTType = llvm::dyn_cast<TensorType>(bType);
1042 if (!aTType) {
1043 op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;
1044 return failure();
1045 }
1046 if (!bTType) {
1047 op.emitOpError("expect shaped tensor for") << bName << ", got" << bType;
1048 return failure();
1049 }
1050 auto aElementType = aTType.getElementType();
1051 auto bElementType = bTType.getElementType();
1052 auto aQuantType =
1053 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
1054 auto bQuantType =
1055 llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
1056 if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
1057 (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
1058 aElementType != bElementType) {
1059 // only check if both element types are int/index/float/UniformQuantized
1060 // eg, not sure how to check quant::QuantizedType
1061 // this happens in test_conv2d_q_grouped_convolution in
1062 // tfl-to-tosa-pipeline.mlir
1063 op.emitOpError("expect ")
1064 << aName << " and " << bName << " to have same element type, got "
1065 << aElementType << " and " << bElementType;
1066 return failure();
1067 }
1068 return success();
1069}
1070
// Verify tosa.argmax: integer result element type, axis within input rank,
// and (when both shapes are ranked) result shape equals input shape with the
// reduction axis removed.
LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  // Ensure output is of 32-bit integer
  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  // Ensure axis is within the tensor rank
  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  // Expected output shape: input shape with the `axis` dimension erased.
  const ArrayRef<int64_t> inputShape = inputType.getShape();
  const ArrayRef<int64_t> outputShape = resultType.getShape();
  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";

  return success();
}
1101
// Shared verifier for 2D pooling ops. `kernel`, `strides`, and `padding` may
// each be empty (e.g. for adaptive pooling ops whose shape operands are not
// compile-time constants); empty lists are simply not checked.
static LogicalResult verifyPoolingOpImpl(Operation *op,
                                         ArrayRef<int64_t> kernel,
                                         ArrayRef<int64_t> strides,
                                         ArrayRef<int64_t> padding, Value input,
                                         Value output) {
  const bool hasKernel = kernel.size() > 0;
  const bool hasStrides = strides.size() > 0;
  const bool hasPad = padding.size() > 0;

  // Basic range checks on the attribute values.
  if (hasKernel && llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op->emitOpError("expect all kernel values to be >= 1, got ")
           << kernel;

  if (hasStrides && llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op->emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  if (hasPad && llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op->emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  if (hasKernel && hasPad) {
    // Padding must be less than kernel size to avoid a divide-by-zero
    // Layouts: kernel = [y, x]; padding = [top, bottom, left, right].
    const int64_t kernelX = kernel[1];
    const int64_t padLeft = padding[2];
    const int64_t padRight = padding[3];
    if (padRight >= kernelX || padLeft >= kernelX)
      return op->emitOpError("expected left/right padding to be less than the "
                             "width of the kernel, got pad_left=")
             << padLeft << ", pad_right=" << padRight
             << ", kernel_x=" << kernelX;

    const int64_t kernelY = kernel[0];
    const int64_t padTop = padding[0];
    const int64_t padBottom = padding[1];
    if (padTop >= kernelY || padBottom >= kernelY)
      return op->emitOpError("expected top/bottom padding to be less than the "
                             "height of the kernel, got pad_top=")
             << padTop << ", pad_bottom=" << padBottom
             << ", kernel_y=" << kernelY;
  }

  // Output-size checks require ranked input and output.
  const auto inputType = llvm::dyn_cast<RankedTensorType>(input.getType());
  const auto outputType = llvm::dyn_cast<RankedTensorType>(output.getType());
  if (!inputType || !outputType)
    return success();

  if (hasKernel && hasStrides && hasPad) {
    // Check that the declared output size matches the size computed from
    // input size, padding, kernel, and stride for one spatial dimension.
    const auto verifyOutputSize =
        [op](const int64_t inputSize, const int64_t outputSize,
             const int64_t kernelSize, const int64_t strideSize,
             const int64_t padBefore, const int64_t padAfter,
             const llvm::StringRef dimName, const llvm::StringRef dimAxis,
             const llvm::StringRef padBeforeName,
             const llvm::StringRef padAfterName) -> LogicalResult {
      if (ShapedType::isDynamic(inputSize))
        return success();

      // idivCheck fails when the division is not exact.
      const std::optional<int64_t> calculatedOutSizeMinusOne =
          idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
      if (!calculatedOutSizeMinusOne.has_value())
        return op->emitOpError("expected input_")
               << dimName << " + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - kernel_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " + " << padBefore << " + " << padAfter << " - "
               << kernelSize << ") / " << strideSize;

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
        return op->emitOpError("calculated output ")
               << dimName << " did not match expected: " << "calculated="
               << calculatedOutSize << ", expected=" << outputSize;

      return success();
    };

    // Dim 1 is height, dim 2 is width (NHWC layout).
    if (failed(verifyOutputSize(inputType.getDimSize(1),
                                outputType.getDimSize(1), kernel[0], strides[0],
                                padding[0], padding[1], "height", "y", "top",
                                "bottom")))
      return failure();

    if (failed(verifyOutputSize(
            inputType.getDimSize(2), outputType.getDimSize(2), kernel[1],
            strides[1], padding[2], padding[3], "width", "x", "left", "right")))
      return failure();
  }
  return success();
}
1192
1193template <typename T>
1194static LogicalResult verifyPoolingOp(T op) {
1195 return verifyPoolingOpImpl(op.getOperation(), op.getKernel(), op.getStride(),
1196 op.getPad(), op.getInput(), op.getOutput());
1197}
1198
// Checks shared by avg_pool2d and avg_pool2d_adaptive: accumulator type is
// appropriate for the input element type, input/output zero points have the
// matching storage element type, and the zero-point values themselves are
// valid for the op.
template <typename T>
static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op) {
  const Type inputETy = getStorageElementTypeOrSelf(op.getInput().getType());
  const Type resultETy = getStorageElementTypeOrSelf(op.getOutput().getType());
  const Type inputZpETy =
      getStorageElementTypeOrSelf(op.getInputZp().getType());
  const Type outputZpETy =
      getStorageElementTypeOrSelf(op.getOutputZp().getType());

  // Accumulator type constraints per input element type.
  auto accType = op.getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  // Zero points must share the storage element type of the tensor they apply
  // to.
  if (inputETy != inputZpETy)
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return op.emitOpError("expect both output and its zero point are the same "
                          "element type, got ")
           << resultETy << " and " << outputZpETy;

  // Validate zero-point values when they are resolvable (constant).
  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
  if (succeeded(maybeOZp) && op.verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
1241
namespace {
// Constant kernel/stride/pad values extracted from an adaptive pooling op's
// const_shape operands. A vector is left empty when the corresponding operand
// is not a compile-time constant and therefore cannot be checked.
struct AdaptivePoolingConstShapeValues {
  llvm::SmallVector<int64_t> kernel;
  llvm::SmallVector<int64_t> stride;
  llvm::SmallVector<int64_t> pad;
};
} // namespace
1249
1250template <typename T>
1252 std::is_same_v<T, tosa::AvgPool2dAdaptiveOp> ||
1253 std::is_same_v<T, tosa::MaxPool2dAdaptiveOp>;
1254
1255template <typename T,
1256 typename std::enable_if<IsSupportedAdaptivePoolConstShapeVerifyOp<T>,
1257 int>::type = 0>
1259 T op, AdaptivePoolingConstShapeValues &values) {
1260 tosa::getConstShapeValues(op.getKernel().getDefiningOp(), values.kernel);
1261 tosa::getConstShapeValues(op.getStride().getDefiningOp(), values.stride);
1262 tosa::getConstShapeValues(op.getPad().getDefiningOp(), values.pad);
1263}
1264
1265LogicalResult tosa::AvgPool2dOp::verify() {
1266 if (failed(verifyPoolingOp(*this)))
1267 return failure();
1269 return failure();
1270 return success();
1271}
1272
1273LogicalResult tosa::AvgPool2dAdaptiveOp::verify() {
1274 AdaptivePoolingConstShapeValues values;
1276
1277 // If pad/stride/kernel are not constant, this is okay, we just can't check
1278 // their values. extractAdaptivePoolingConstShapeOperands will return an empty
1279 // list for each non CTC input. verifyPoolingOpImpl will need to handle values
1280 // not being present, and return success if they cannot be checked.
1281
1282 if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
1283 values.pad, getInput(), getOutput())))
1284 return failure();
1285
1287 return failure();
1288
1289 return success();
1290}
1291
// Verify tosa.clamp: input/output element types must match (after unwrapping
// uniform-quantized storage types), and the min_val/max_val attributes must
// match the element type, be ordered min <= max, and (for floats) not be NaN.
LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = getStorageElementTypeFromQuantized(quantType);
  }
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = getStorageElementTypeFromQuantized(quantType);
  }
  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    // if input datatype is integer, check that the min_val/max_val attributes
    // are integer attributes, and that their type is the same as the input's
    // datatype
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    // Unsigned (and i1) values compare with unsigned semantics; everything
    // else compares as signed.
    const bool isUnsigned = inputETy.isUnsignedInteger();
    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    // otherwise, input datatype is float, check that the min_val/max_val
    // attributes share the same type and that their type is the same as the
    // input's datatype
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    // NaN bounds would make the clamp range meaningless.
    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError("min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  }

  return success();
}
1357
1358//===----------------------------------------------------------------------===//
1359// TOSA Operator Quantization Builders.
1360//===----------------------------------------------------------------------===//
1361
1362/// This builder is called on all convolution operators except TransposeConv,
1363/// which has specialized output shape semantics. The builder also defines the
1364/// bitwidth of the output given the bit width of the input & weight content.
1366 Type outputType, Value input, Value weight,
1367 Value bias, DenseI64ArrayAttr pad,
1368 DenseI64ArrayAttr stride,
1369 DenseI64ArrayAttr dilation,
1370 TypeAttr accType) {
1371 auto zps = createZPsAsConst(builder, input, weight);
1372 result.addOperands({input, weight, bias, zps.first, zps.second});
1373 result.addAttribute("pad", pad);
1374 result.addAttribute("stride", stride);
1375 result.addAttribute("dilation", dilation);
1376 result.addAttribute("acc_type", accType);
1377 Type finalOutputType = outputType;
1378 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1379 if (quantAttr) {
1380 finalOutputType =
1381 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1382 }
1383 result.addTypes(finalOutputType);
1384}
1385
1386/// Handles tosa.transpose_conv2d which has outpad and output shape
1387/// attributes.
1388static void
1390 Type outputType, Value input, Value weight,
1391 Value bias, DenseI64ArrayAttr outpad,
1392 DenseI64ArrayAttr stride, TypeAttr accType) {
1393 auto zps = createZPsAsConst(builder, input, weight);
1394 result.addOperands({input, weight, bias, zps.first, zps.second});
1395 result.addAttribute("out_pad", outpad);
1396 result.addAttribute("stride", stride);
1397 result.addAttribute("acc_type", accType);
1398 Type finalOutputType = outputType;
1399 auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
1400 if (quantAttr) {
1401 finalOutputType =
1402 buildConvOpResultTypeInfo(builder, outputType, input, weight);
1403 }
1404 result.addTypes(finalOutputType);
1405}
1406
1407/// The tosa.matmul op is also intended to be generated where a fully_connected
1408/// op must be constructed where the weight is not a constant. In this case,
1409/// the fully_connected op must be expressed using matmul.
1410/// TODO: Add link to the leglization document explaining this.
1412 OperationState &result, Type outputType,
1413 Value a, Value b) {
1414 auto zps = createZPsAsConst(builder, a, b);
1415 result.addOperands({a, b, zps.first, zps.second});
1416
1417 Type finalOutputType{outputType};
1418 if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
1419 auto eType = getStorageElementTypeOrSelf(a.getType());
1420 auto inputBits = eType.getIntOrFloatBitWidth();
1421
1422 auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
1423 assert(outputShapedType && "Output must be a shaped type");
1424
1425 IntegerType accElementType;
1426 if (inputBits == 16)
1427 accElementType = builder.getIntegerType(48);
1428 else
1429 accElementType = builder.getI32Type();
1430
1431 finalOutputType = outputShapedType.clone(accElementType);
1432 }
1433 result.addTypes(finalOutputType);
1434}
1435
1436/// Both the tosa.avg_pool2d and unary ops use the same
1437/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
1438/// has additional parameters not part of the unary ops.
1439static void
1441 Type outputType, Value input,
1442 DenseArrayAttr kernel, DenseArrayAttr stride,
1443 DenseArrayAttr pad, TypeAttr accType) {
1444 const Location loc{result.location};
1445 int64_t inputZp{0};
1446 int64_t outputZp{0};
1447
1448 if (auto quantAttr =
1449 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1450 inputZp = quantAttr.getInputZp();
1451 outputZp = quantAttr.getOutputZp();
1452 }
1453 const std::optional<Value> inputZpOp =
1454 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1455 if (!inputZpOp) {
1456 (void)emitError(
1457 loc,
1458 "Failed to create input zero point tensor for quantized AVG_POOL2D op");
1459 }
1460 const std::optional<Value> outputZpOp =
1461 createZeroPointTensor(builder, loc, outputType, outputZp);
1462 if (!outputZpOp) {
1463 (void)emitError(loc, "Failed to create output zero point tensor for "
1464 "quantized AVG_POOL2D op");
1465 }
1466
1467 if (inputZpOp && outputZpOp) {
1468 result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
1469 } else {
1470 // failed to create one or more zero points above: just add input as
1471 // operands this will trigger error in building the op because of missing
1472 // zero points
1473 result.addOperands({input});
1474 }
1475 result.addAttribute("kernel", kernel);
1476 result.addAttribute("stride", stride);
1477 result.addAttribute("pad", pad);
1478 result.addAttribute("acc_type", accType);
1479 result.types.push_back(outputType);
1480}
1481
1482/// This builder mirrors avg_pool2d quant-info handling and materializes
1483/// kernel/stride/pad as const_shape operands for avg_pool2d_adaptive.
1485 OpBuilder &builder, OperationState &result, Type outputType, Value input,
1487 TypeAttr accType) {
1488 const Location loc{result.location};
1489 int64_t inputZp{0};
1490 int64_t outputZp{0};
1491
1492 if (auto quantAttr =
1493 buildUnaryOpQuantizationAttr(builder, input, outputType)) {
1494 inputZp = quantAttr.getInputZp();
1495 outputZp = quantAttr.getOutputZp();
1496 }
1497 const std::optional<Value> inputZpOp =
1498 createZeroPointTensor(builder, loc, input.getType(), inputZp);
1499 if (!inputZpOp) {
1500 (void)emitError(loc,
1501 "Failed to create input zero point tensor for quantized "
1502 "AVG_POOL2D_ADAPTIVE op");
1503 }
1504 const std::optional<Value> outputZpOp =
1505 createZeroPointTensor(builder, loc, outputType, outputZp);
1506 if (!outputZpOp) {
1507 (void)emitError(loc, "Failed to create output zero point tensor for "
1508 "quantized AVG_POOL2D_ADAPTIVE op");
1509 }
1510
1511 if (inputZpOp && outputZpOp) {
1512 ImplicitLocOpBuilder b(loc, builder);
1513 Value kernelShape = getTosaConstShape(b, kernel.asArrayRef());
1514 Value strideShape = getTosaConstShape(b, stride.asArrayRef());
1515 Value padShape = getTosaConstShape(b, pad.asArrayRef());
1516 result.addOperands({input, inputZpOp.value(), outputZpOp.value(),
1517 kernelShape, strideShape, padShape});
1518 } else {
1519 // Failed to create one or more zero points above: just add input as
1520 // operands. This will trigger error in building the op because of missing
1521 // operands.
1522 result.addOperands({input});
1523 }
1524 result.addAttribute("acc_type", accType);
1525 result.types.push_back(outputType);
1526}
1527
1528/// This builder is called on single-parameter negate operator
1529/// to construct input and output zero points based on their
1530/// types.
1532 OperationState &result, Type outputType,
1533 Value input) {
1534 const Location loc{result.location};
1535 int64_t input1Zp{0};
1536 int64_t outputZp{0};
1537 auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
1538 if (quantAttr) {
1539 input1Zp = quantAttr.getInputZp();
1540 outputZp = quantAttr.getOutputZp();
1541 }
1542 const std::optional<Value> input1ZpOp =
1543 createZeroPointTensor(builder, loc, input.getType(), input1Zp);
1544 if (!input1ZpOp) {
1545 (void)emitError(
1546 loc, "Failed to create input1 zero point for quantized NEGATE op");
1547 }
1548
1549 const std::optional<Value> outputZpOp =
1550 createZeroPointTensor(builder, loc, input.getType(), outputZp);
1551 if (!outputZpOp) {
1552 (void)emitError(
1553 loc, "Failed to create output zero point for quantized NEGATE op");
1554 }
1555
1556 if (input1ZpOp && outputZpOp) {
1557 result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
1558 } else {
1559 // failed to create one or more zero points above: just add input as
1560 // operands. This will trigger error in building the op because of
1561 // missing zero points
1562 result.addOperands({input});
1563 }
1564
1565 result.types.push_back(outputType);
1566}
1567
1568/// This builder is called on TOSA pad operator that needs to create its own
1569/// OptionalAttr quantization_attr parameter to scale the padding values
1570/// correctly. No pad_const is interpreted as zero-padding.
1572 Type outputType, Value input,
1573 Value paddings) {
1574 const Location loc{result.location};
1575 int32_t zp{0};
1576 const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
1577 if (quantAttr) {
1578 zp = static_cast<int32_t>(quantAttr.getInputZp());
1579 }
1580 const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
1581 result.addOperands({input, paddings, padConstOp});
1582 result.types.push_back(outputType);
1583}
1584
1586 StringRef name, Type variableType,
1587 Attribute initialValue) {
1588 const Location loc{result.location};
1589 auto nameAttr = builder.getStringAttr(name);
1590
1591 auto shapedType = dyn_cast<ShapedType>(variableType);
1592 if (!shapedType) {
1593 (void)emitError(loc, "variable type must be a shaped type");
1594 return;
1595 }
1596 if (!shapedType.hasRank()) {
1597 (void)emitError(loc, "variable type must be a ranked type");
1598 return;
1599 }
1600
1601 auto elementType = shapedType.getElementType();
1602 auto elementTypeAttr = TypeAttr::get(elementType);
1603 ArrayRef<int64_t> shape = shapedType.getShape();
1604 auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
1605
1606 result.addAttribute("sym_name", nameAttr);
1607 result.addAttribute("var_shape", varShapeAttr);
1608 result.addAttribute("type", elementTypeAttr);
1609 result.addAttribute("initial_value", initialValue);
1610}
1611
1612//===----------------------------------------------------------------------===//
1613// TOSA Operator Return Type Inference.
1614//===----------------------------------------------------------------------===//
1615static FailureOr<int64_t> resolveBroadcastDim(const int64_t dim1,
1616 const int64_t dim2) {
1617 if (dim1 == 1)
1618 return dim2;
1619 if (dim2 == 1)
1620 return dim1;
1621
1622 if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2)
1623 return failure();
1624
1625 // Prefer static dimension over dynamic
1626 return ShapedType::isDynamic(dim1) ? dim2 : dim1;
1627}
1628
1629static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
1630 SmallVector<int64_t> &outShape) {
1631 int64_t outRank = 0;
1632 for (int i = 0, e = operands.size(); i != e; ++i) {
1633 auto shape = operands.getShape(i);
1634 if (!shape.hasRank()) {
1635 // TODO(jennik): Update function to have better case handling for
1636 // invalid operands and for ranked tensors.
1637 return failure();
1638 }
1639 outRank = std::max<int64_t>(outRank, shape.getRank());
1640 }
1641
1642 outShape.resize(outRank, 1);
1643
1644 for (int i = 0, e = operands.size(); i != e; ++i) {
1645 auto shape = operands.getShape(i);
1646 auto rankDiff = outShape.size() - shape.getRank();
1647
1648 for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
1649 auto dim1 = outShape[i + rankDiff];
1650 auto dim2 = shape.getDimSize(i);
1651
1652 const FailureOr<int64_t> maybeResolvedDim =
1653 resolveBroadcastDim(dim1, dim2);
1654 if (failed(maybeResolvedDim))
1655 return failure();
1656 const int64_t resolvedDim = *maybeResolvedDim;
1657 outShape[i + rankDiff] = resolvedDim;
1658 }
1659 }
1660
1661 return success();
1662}
1663
1664LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
1665 MLIRContext *context, ::std::optional<Location> location,
1666 ArgMaxOp::Adaptor adaptor,
1667 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1668 ShapeAdaptor inputShape(adaptor.getInput().getType());
1669 IntegerAttr axis = adaptor.getProperties().axis;
1670 int32_t axisVal = axis.getValue().getSExtValue();
1671
1672 if (!inputShape.hasRank()) {
1673 inferredReturnShapes.push_back(ShapedTypeComponents());
1674 return success();
1675 }
1676
1677 SmallVector<int64_t> outShape;
1678 outShape.reserve(inputShape.getRank() - 1);
1679 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
1680 if (i == axisVal)
1681 continue;
1682 outShape.push_back(inputShape.getDimSize(i));
1683 }
1684
1685 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
1686 return success();
1687}
1688
1689LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
1690 MLIRContext *context, ::std::optional<Location> location,
1691 RFFT2dOp::Adaptor adaptor,
1692 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1693 ShapeAdaptor inputShape(adaptor.getInputReal().getType());
1694
1695 if (!inputShape.hasRank())
1696 return failure();
1697
1698 llvm::SmallVector<int64_t> outputShape;
1699 outputShape.resize(3, ShapedType::kDynamic);
1700 outputShape[0] = inputShape.getDimSize(0);
1701 outputShape[1] = inputShape.getDimSize(1);
1702 int64_t inWidth = inputShape.getDimSize(2);
1703
1704 // Note that we can support this calculation symbolically
1705 // in the future e.g. [x, y, z] -> [x, y, z / 2 + 1]
1706 if (inWidth != ShapedType::kDynamic)
1707 outputShape[2] = inWidth / 2 + 1;
1708
1709 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1710 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
1711
1712 return success();
1713}
1714
1715static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
1716 const llvm::StringRef dimName) {
1717 const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
1718 if (!isPowerOfTwo)
1719 return op->emitOpError("expected ")
1720 << dimName << " to be a power of two, got " << dimSize;
1721
1722 return success();
1723}
1724
// Verify tosa.rfft2d: both results share a shape; static input height/width
// are powers of two; batch/height dims match between input and output; and
// the output width equals input_width / 2 + 1 when both are static.
LogicalResult tosa::RFFT2dOp::verify() {
  const auto outputTypes = getResultTypes();
  if (failed(verifyCompatibleShapes(outputTypes)))
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  // Remaining checks need a ranked input.
  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  if (!inputType)
    return success();

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
  if (!outputType)
    return success();

  // Batch and height input/output dimensions should match
  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
                                   outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  // Output width dimension expected to be input_width / 2 + 1
  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
    return emitOpError(
               "expected output width to be equal to input_width / 2 + 1, got ")
           << outputWidth;

  return success();
}
1766
1767LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
1768 MLIRContext *context, ::std::optional<Location> location,
1769 FFT2dOp::Adaptor adaptor,
1770 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1771 inferredReturnShapes.push_back(
1772 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
1773 inferredReturnShapes.push_back(
1774 ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
1775 return success();
1776}
1777
1778LogicalResult tosa::FFT2dOp::verify() {
1779 const auto inputRealType =
1780 llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
1781 const auto inputImagType =
1782 llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
1783 if (!inputRealType || !inputImagType)
1784 return success();
1785
1786 const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
1787 return ShapedType::isDynamic(a) ? a : b;
1788 };
1789
1790 const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
1791 inputImagType.getDimSize(1));
1792 if (ShapedType::isStatic(height) &&
1793 failed(verifyDimIsPowerOfTwo(*this, height, "height")))
1794 return failure();
1795
1796 const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
1797 inputImagType.getDimSize(2));
1798 if (ShapedType::isStatic(width) &&
1799 failed(verifyDimIsPowerOfTwo(*this, width, "width")))
1800 return failure();
1801
1802 return success();
1803}
1804
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  // First pass: establish the rank and all non-axis dimensions from whichever
  // operands are ranked; conflicting static non-axis dims are an error.
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // Copy the Operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  // At least one input is required (its element type seeds the result type).
  if (adaptor.getInput1().empty())
    return failure();

  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0]).getElementType();
  // All inputs unranked: only the element type can be inferred.
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}
1869
LogicalResult tosa::ConcatOp::verify() {
  // check that each input has same element type as output
  auto outType = getOutput().getType();
  const Operation::operand_range inputList = getInput1();

  // Check there is at least one input
  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, /* inType = */ input.getType(), outType));
      })) {
    return failure();
  }

  const int32_t axis = getAxis();
  ShapeAdaptor firstRankedInputShape = nullptr;
  // Locate the first ranked input; its rank bounds the valid axis range.
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    ShapeAdaptor currShape(inputType);
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      // Check axis is in expected range
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 < axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")
               << axis;
      break;
    }
  }

  // The remaining checks require every operand to be ranked.
  const auto allOperandsHasRank = [](const Value input) {
    return ShapeAdaptor(input.getType()).hasRank();
  };
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const ShapeAdaptor inputShape(input.getType());
      const int64_t inputRank = inputShape.getRank();
      // +1 because the enumeration skips operand 0 (drop_front above).
      const size_t operandNum = index + 1;

      // Check that each operand has the same rank
      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      // Check non-axis dims match
      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        // Dynamic dims are compatible with anything; the axis dim may differ.
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }

    const ShapeAdaptor outputShape(outType);
    if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
      return emitOpError("expect output rank to match inputs rank, got ")
             << outputShape.getRank() << " vs " << firstInputRank;

    // ERROR_IF(axis_sum != shape[axis]);
    // Sum the axis dims of all inputs; any dynamic axis dim makes the sum
    // unknowable, so the check below is skipped.
    int64_t axisSum = 0;
    for (const auto &input : inputList) {
      const ShapeAdaptor inputShape(input.getType());
      if (inputShape.isDynamicDim(axis)) {
        // make axisSum negative to indicate invalid value
        axisSum = -1;
        break;
      }
      axisSum += inputShape.getDimSize(axis);
    }

    if (axisSum >= 0 && outputShape.hasRank() &&
        !outputShape.isDynamicDim(axis) &&
        axisSum != outputShape.getDimSize(axis))
      return emitOpError("requires sum of axis dimensions of input1 "
                         "equal to output axis dimension, got ")
             << axisSum << " and " << outputShape.getDimSize(axis);
  }

  return success();
}
1962
1963LogicalResult tosa::EqualOp::inferReturnTypeComponents(
1964 MLIRContext *context, ::std::optional<Location> location,
1965 ValueShapeRange operands, DictionaryAttr attributes, PropertyRef properties,
1966 RegionRange regions,
1967 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1968 auto elementType = IntegerType::get(context, /*width=*/1);
1969
1971 if (resolveBroadcastShape(operands, outShape).failed()) {
1972 inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
1973 return success();
1974 }
1975
1976 inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
1977 return success();
1978}
1979
1980bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1981 if (l.size() != r.size() || l.size() != 1)
1982 return false;
1983 return succeeded(verifyCompatibleShape(l[0], r[0]));
1984}
1985
1986LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
1987 MLIRContext *context, ::std::optional<Location> location,
1988 MatMulOp::Adaptor adaptor,
1989 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1990 ShapeAdaptor lhsShape(adaptor.getA().getType());
1991 ShapeAdaptor rhsShape(adaptor.getB().getType());
1992
1993 // All shapes are dynamic.
1994 SmallVector<int64_t> outShape;
1995 outShape.resize(3, ShapedType::kDynamic);
1996
1997 if (lhsShape.hasRank()) {
1998 outShape[0] = lhsShape.getDimSize(0);
1999 outShape[1] = lhsShape.getDimSize(1);
2000 }
2001
2002 if (rhsShape.hasRank()) {
2003 outShape[0] = outShape[0] == ShapedType::kDynamic ? rhsShape.getDimSize(0)
2004 : outShape[0];
2005 outShape[2] = rhsShape.getDimSize(2);
2006 }
2007
2008 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2009 return success();
2010}
2011
LogicalResult MatMulOp::verify() {
  const ShapeAdaptor aShape(getA().getType());
  const ShapeAdaptor bShape(getB().getType());
  const Type aElementType = aShape.getElementType();
  const Type bElementType = bShape.getElementType();

  const auto aQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(aElementType);
  const auto bQuantizedEType =
      llvm::dyn_cast<quant::UniformQuantizedType>(bElementType);

  // Operands must be either both quantized or both non-quantized, and when
  // quantized their storage widths must agree.
  if (aQuantizedEType || bQuantizedEType) {
    if (!aQuantizedEType || !bQuantizedEType) {
      return emitOpError("expect operands to be both quantized or both not "
                         "quantized, got ")
             << aElementType << " and " << bElementType;
    }
    // both a and b have quantized element types
    auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth();
    auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth();
    if (aQuantWidth != bQuantWidth) {
      return emitOpError("expect quantized operands to have same widths, got ")
             << aQuantWidth << " and " << bQuantWidth;
    }
  }

  // check a_zp and b_zp
  // Each zero-point operand must share the (storage) element type of the
  // corresponding input.
  auto aEType = getStorageElementTypeOrSelf(aElementType);
  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
  if (aEType != aZpEType)
    return emitOpError("expect input a and a_zp have the same "
                       "element type, got ")
           << aEType << " and " << aZpEType;

  const Type bEType = getStorageElementTypeOrSelf(bElementType);
  const Type bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
  if (bEType != bZpEType)
    return emitOpError("expect input b and b_zp have the same "
                       "element type, got ")
           << bEType << " and " << bZpEType;

  // Zero-point values (when resolvable) must pass the per-operand checks.
  FailureOr<int64_t> maybeAZp = getAZeroPoint();
  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
    return failure();

  FailureOr<int64_t> maybeBZp = getBZeroPoint();
  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
    return failure();

  // Verify input/output shapes
  // N taken from a dim 0 (cross-checked against b dim 0), H from a dim 1,
  // C shared between a dim 2 and b dim 1, W from b dim 2.
  int64_t N = ShapedType::kDynamic;
  int64_t H = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;

  if (aShape.hasRank()) {
    N = aShape.getDimSize(0);
    H = aShape.getDimSize(1);
    C = aShape.getDimSize(2);
  }

  if (bShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, N, bShape.getDimSize(0), "b",
                                     "batch")) ||
        failed(tryUpdateDimOrFailure(*this, C, bShape.getDimSize(1), "b",
                                     "channels")))
      return failure();
    W = bShape.getDimSize(2);
  }

  // The result (when ranked) must be compatible with [N, H, W].
  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
  const auto outputType = cast<ShapedType>(getResult().getType());
  if (outputType.hasRank() &&
      failed(
          verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
    InFlightDiagnostic opError = emitOpError("expected output shape ");
    printShapeToDiagnostic(opError, outputType.getShape());
    opError << " to be compatible with expected output shape ";
    printShapeToDiagnostic(opError, expectedOutputShape);
    return opError;
  }

  return success();
}
2096
LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    MatmulTBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Result shape [batch, height, width] is refined from whichever of the
  // four operands (A/B data and their scales) are ranked.
  SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);

  // Batch and height come from A's data, falling back to A's scale.
  const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
  if (aDataShape.hasRank()) {
    outShape[0] = aDataShape.getDimSize(0);
    outShape[1] = aDataShape.getDimSize(1);
  }

  const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
  if (aScaleShape.hasRank()) {
    outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
                                                     : outShape[0];
    outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
                                                     : outShape[1];
  }

  // If B batch size is 1, it is broadcast across A's batch size
  const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
  if (bDataShape.hasRank()) {
    const int64_t bDataBatchSize = bDataShape.getDimSize(0);
    if (bDataBatchSize != 1)
      outShape[0] =
          ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
    // Output width comes from B's dimension 1 (B is transposed).
    outShape[2] = bDataShape.getDimSize(1);
  }

  // B's scale may still fill in batch/width if the data operand left them
  // dynamic.
  const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
  if (bScaleShape.hasRank()) {
    const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
    if (bScaleBatchSize != 1)
      outShape[0] =
          ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
    outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
                                                     : outShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}
2140
LogicalResult MatmulTBlockScaledOp::verify() {
  // Verify same input data types
  const Type aDataType = getAData().getType();
  const Type bDataType = getBData().getType();
  if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
                                    "B_data")))
    return failure();

  // Verify input shape compatibility
  // N/D: batch of A/B, H: height (A dim 1), W: width (B dim 1),
  // C: channels (A dim 2 / B dim 2), multiplesOfC: scale dim 2, which the
  // checks below require to equal C / block_size.
  int64_t N = ShapedType::kDynamic;
  int64_t D = ShapedType::kDynamic;
  int64_t H = ShapedType::kDynamic;
  int64_t W = ShapedType::kDynamic;
  int64_t C = ShapedType::kDynamic;
  int64_t multiplesOfC = ShapedType::kDynamic;

  const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
  if (aDataShape.hasRank()) {
    N = aDataShape.getDimSize(0);
    H = aDataShape.getDimSize(1);
    C = aDataShape.getDimSize(2);
  }

  // A's scale must agree with A's data on batch and height.
  const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
  if (aScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
                                     "a_scale", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
                                     "a_scale", "height")))
      return failure();
    multiplesOfC = aScaleShape.getDimSize(2);
  }

  const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
  if (bDataShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
                                     "b_data", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
                                     "b_data", "channels")))
      return failure();
    W = bDataShape.getDimSize(1);
  }

  // B's scale must agree with B's data on batch/width and with A's scale on
  // the C/block_size dimension.
  const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
  if (bScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
                                     "b_scale", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
                                     "b_scale", "width")) ||
        failed(tryUpdateDimOrFailure(*this, multiplesOfC,
                                     bScaleShape.getDimSize(2), "b_scale",
                                     "C/block_size")))
      return failure();
  }

  // Verify batch size is broadcast compatible
  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
    return emitOpError("expect B matrix batch size to be broadcast compatible "
                       "with A, got D=")
           << D << " vs N=" << N;

  // Verify C is a multiple of block size
  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
    return emitOpError("expect block size to be 32, got ") << blockSize;
  if (ShapedType::isStatic(C) && C % blockSize != 0)
    return emitOpError("expect C to be a multiple of block size, got C=")
           << C << ", block_size=" << blockSize;

  // Verify multiplesOfC is C / block size
  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
      multiplesOfC != C / blockSize)
    return emitOpError(
               "expect scale operands dimension 2 to equal C/block_size (")
           << C << "/" << blockSize << ")" << ", got " << multiplesOfC;

  // Verify output shape
  // A dynamic N may still be pinned by D (broadcast case).
  N = ShapedType::isDynamic(N) ? D : N;
  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
  const auto outputType = cast<ShapedType>(getResult().getType());
  if (outputType.hasRank() &&
      failed(
          verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
    InFlightDiagnostic opError = emitOpError("expected output shape ");
    printShapeToDiagnostic(opError, outputType.getShape());
    opError << " to be compatible with expected output shape ";
    printShapeToDiagnostic(opError, expectedOutputShape);
    return opError;
  }

  return success();
}
2233
2234LogicalResult tosa::PadOp::inferReturnTypeComponents(
2235 MLIRContext *context, ::std::optional<Location> location,
2236 PadOp::Adaptor adaptor,
2237 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2238 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2239 auto paddingRank =
2240 cast<tosa::shapeType>(adaptor.getPadding().getType()).getRank();
2241 SmallVector<int64_t> outputShape;
2242
2243 // If the input rank is unknown, we can infer the output rank using the
2244 // padding shape's rank divided by 2.
2245 if (!inputShape.hasRank()) {
2246 outputShape.resize(paddingRank / 2, ShapedType::kDynamic);
2247 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2248 return success();
2249 }
2250
2251 SmallVector<int64_t> paddingValues;
2252 // If the paddings value is not a constant, all dimensions must be dynamic.
2253 if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(),
2254 paddingValues)) {
2255 outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
2256 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2257 return success();
2258 }
2259
2260 outputShape.reserve(inputShape.getRank());
2261 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2262 if (inputShape.isDynamicDim(i)) {
2263 outputShape.push_back(ShapedType::kDynamic);
2264 continue;
2265 }
2266 auto padFront = paddingValues[i * 2];
2267 auto padBack = paddingValues[i * 2 + 1];
2268 if (padFront < 0 || padBack < 0) {
2269 // if either padding for dim i is -1, output dim is unknown
2270 outputShape.push_back(ShapedType::kDynamic);
2271 continue;
2272 }
2273
2274 outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack);
2275 }
2276
2277 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2278 return success();
2279}
2280
LogicalResult tosa::PadOp::verify() {
  // Input, output and (if present) pad_const must share an element type.
  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  if (auto padConst = getPadConst()) {
    if (verifySameElementTypes(*this, /* inType = */ padConst.getType(),
                               /* outType = */ getOutput().getType())
            .failed()) {
      return failure();
    }
  }

  // The remaining checks need both sides ranked.
  RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(getInput1().getType());
  RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!inputType || !outputType)
    return success();

  if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
                              "output")))
    return failure();

  auto inputRank = inputType.getRank();
  // Padding values can only be checked when they are a compile-time constant.
  DenseIntElementsAttr paddingAttr;
  if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
    return success();

  // Padding holds a (start, end) pair per input dimension.
  auto paddingValues = paddingAttr.getValues<APInt>();
  if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
    return emitOpError() << "padding tensor must have " << inputRank
                         << " * 2 = " << inputRank * 2 << " elements, but got "
                         << paddingValues.size();

  auto inputShape = inputType.getShape();
  auto outputShape = outputType.getShape();

  for (int64_t i = 0; i < inputRank; ++i) {
    int64_t padStart = paddingValues[i * 2].getSExtValue();
    int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue();

    // -1 is the only allowed negative value (dynamic padding).
    if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) {
      return emitOpError()
             << "invalid padding values at dimension " << i
             << ": values must be non-negative or -1 for dynamic padding, got ["
             << padStart << ", " << padEnd << "]";
    }

    // Skip shape verification for dynamic input/output
    if (inputShape[i] == ShapedType::kDynamic ||
        outputShape[i] == ShapedType::kDynamic)
      continue;

    // output[i] must equal input[i] + padStart + padEnd.
    if (outputShape[i] != inputShape[i] + padStart + padEnd) {
      return emitOpError() << "mismatch in output shape at dimension " << i
                           << ": expected " << inputShape[i] << " + "
                           << padStart << " + " << padEnd << " = "
                           << (inputShape[i] + padStart + padEnd)
                           << ", but got " << outputShape[i];
    }
  }

  return success();
}
2348
2349LogicalResult tosa::SliceOp::inferReturnTypeComponents(
2350 MLIRContext *context, ::std::optional<Location> location,
2351 SliceOp::Adaptor adaptor,
2352 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2353
2354 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2357
2358 if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) ||
2359 !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) {
2360 auto rank = cast<tosa::shapeType>(adaptor.getSize().getType()).getRank();
2361 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2362 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2363 return success();
2364 }
2365
2366 // if size[i] is -1, all remaining elements in dimension i are included
2367 // in the slice, similar to TF.
2368 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2369 // initialize outputShape to all unknown
2370 SmallVector<int64_t> outputShape(size.size(), ShapedType::kDynamic);
2371 if (inputShape.hasRank()) {
2372 for (size_t i = 0; i < size.size(); i++) {
2373 if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 &&
2374 (ShapedType::isDynamic(inputShape.getDimSize(i)) ||
2375 start[i] < inputShape.getDimSize(i))) {
2376 // size[i] is not 0 and not < -1, and start[i] is in valid range
2377 if (ShapedType::isDynamic(inputShape.getDimSize(i))) {
2378 // input shape has unknown dim[i] - only valid if size[i] > 0
2379 if (size[i] > 0) {
2380 outputShape[i] = size[i];
2381 }
2382 } else {
2383 // input shape has known dim[i]
2384 if (size[i] == -1) {
2385 outputShape[i] = inputShape.getDimSize(i) - start[i];
2386 } else if (start[i] + size[i] <= inputShape.getDimSize(i)) {
2387 // start[i] + size[i] is within bound of input shape's dim[i]
2388 outputShape[i] = size[i];
2389 }
2390 }
2391 }
2392 }
2393 } else {
2394 outputShape = convertToMlirShape(size);
2395 }
2396 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
2397 return success();
2398}
2399
LogicalResult tosa::SliceOp::verify() {
  // Input and output must share an element type.
  const Value input = getInput1();
  const Value output = getOutput();
  if (verifySameElementTypes(*this, /* inType = */ input.getType(),
                             /* outType = */ output.getType())
          .failed())
    return failure();

  const Value start = getStart();
  const Value size = getSize();
  const ShapeAdaptor inputShape(input.getType());
  const ShapeAdaptor outputShape(output.getType());

  // Rank consistency: input/output ranks match, and the start/size shape
  // operands each have one entry per input dimension.
  if (inputShape.hasRank()) {
    const auto inputRank = inputShape.getRank();
    if (outputShape.hasRank() && inputRank != outputShape.getRank())
      return emitOpError(
                 "expect input1 and output to have the same ranks, got ")
             << inputRank << " and " << outputShape.getRank();

    const auto startShapeRank =
        llvm::cast<tosa::shapeType>(start.getType()).getRank();
    if (inputRank != startShapeRank)
      return emitOpError("length of start is not equal to rank of input shape");

    const auto sizeShapeRank =
        llvm::cast<tosa::shapeType>(size.getType()).getRank();
    if (inputRank != sizeShapeRank)
      return emitOpError("length of size is not equal to rank of input shape");
  }

  // Constant start values, when available, must be non-negative
  // (kInferableDimSize is the allowed sentinel).
  SmallVector<int64_t> startValues;
  tosa::getConstShapeValues(start.getDefiningOp(), startValues);
  if (startValues.size()) {
    if (llvm::any_of(startValues, [](const int64_t v) {
          return v < 0 && v != kInferableDimSize;
        }))
      return emitOpError("start values must be non-negative, got [")
             << startValues << "]";
  }

  // Non-constant size: nothing more to check.
  SmallVector<int64_t> sizeValues;
  if (!tosa::getConstShapeValues(size.getDefiningOp(), sizeValues))
    return success();

  if (llvm::any_of(sizeValues, [](const int64_t v) {
        return v <= 0 && v != kInferableDimSize;
      }))
    return emitOpError("size values must be > 0, got [") << sizeValues << "]";
  // With fully-specified sizes, the output shape must be compatible with them.
  if (outputShape.hasRank()) {
    SmallVector<int64_t> outputDims;
    outputShape.getDims(outputDims);
    const bool hasNoInferableDims = llvm::all_of(
        sizeValues, [](const int64_t v) { return v != kInferableDimSize; });
    if (hasNoInferableDims &&
        failed(verifyCompatibleShape(outputDims, sizeValues)))
      return emitOpError("expected output shape to match size values, got ")
             << output.getType() << " vs [" << sizeValues << "]";
  }

  // Per-dimension bounds check: start + size must fit inside the input dim.
  if (inputShape.hasRank() && startValues.size()) {
    SmallVector<int64_t> inputDims;
    inputShape.getDims(inputDims);
    for (const auto &[index, vals] :
         llvm::enumerate(llvm::zip_equal(startValues, sizeValues, inputDims))) {
      const auto &[start, size, inputDim] = vals;
      // Skip dims that are inferable or dynamic.
      if (start == kInferableDimSize || size == kInferableDimSize ||
          ShapedType::isDynamic(inputDim))
        continue;
      if (start + size > inputDim)
        return emitOpError("start + size must be less than or equal to input "
                           "dimension size, got start=")
               << start << ", size=" << size
               << " vs input dim size=" << inputDim << " at dimension "
               << index;
    }
  }

  return success();
}
2480
2481LogicalResult tosa::MulOp::inferReturnTypeComponents(
2482 MLIRContext *context, ::std::optional<Location> location,
2483 ValueShapeRange operands, DictionaryAttr attributes, PropertyRef properties,
2484 RegionRange regions,
2485 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2486 // mul op's output shape only depend on input1 and input2, not on shift
2487 ValueShapeRange twoInputs = operands.drop_back();
2489 if (resolveBroadcastShape(twoInputs, outShape).failed()) {
2490 inferredReturnShapes.push_back(ShapedTypeComponents());
2491 } else {
2492 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
2493 }
2494 return success();
2495}
2496
2497LogicalResult tosa::MulOp::verify() {
2498 const Value output = getOutput();
2499 auto resElemType = getElementTypeOrSelf(output);
2500
2501 // Verify if the element type among operands and result match tosa
2502 // specification.
2503 if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
2504 IntegerType lhsIntType =
2505 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput1()));
2506 IntegerType rhsIntType =
2507 dyn_cast<IntegerType>(getElementTypeOrSelf(getInput2()));
2508 if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType)
2509 return emitOpError("requires the same element type for all operands");
2510
2511 // Though the spec requires the element type of result to be i32, a more
2512 // relaxed way is provided at dialect level for easier cooperating with
2513 // other dialects.
2514 if (lhsIntType.getWidth() > resIntType.getWidth())
2515 return emitOpError("invalid data type size for operands or result");
2516
2517 } else {
2518 // For other supported type, the spec requires requires the same element
2519 // type for all operands (excludes `shift` operand) and results.
2520 for (int i = 0; i < 2; ++i) {
2521 if (getElementTypeOrSelf(getOperand(i)) != resElemType)
2522 return emitOpError(
2523 "requires the same element type for all operands and results");
2524 }
2525
2526 // verify shift has value 0 for non-integer types
2527 ElementsAttr shiftElem;
2528 if (matchPattern(getShift(), m_Constant(&shiftElem))) {
2529 int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
2530 if (shift != 0) {
2531 return emitOpError() << "require shift to be 0 for float type";
2532 }
2533 }
2534 }
2535
2536 // Verify the op has same ranks for all main operands (excludes extra operands
2537 // such as shift of mul op, so this is the only difference with the built-in
2538 // `SameOperandsAndResultRank` trait) and results types, if known.
2539 TypeRange operandTypes = getOperandTypes();
2540 ShapedType aType = cast<ShapedType>(operandTypes[0]);
2541 ShapedType bType = cast<ShapedType>(operandTypes[1]);
2542
2543 const bool aHasRank = aType.hasRank();
2544 const bool bHasRank = bType.hasRank();
2545 if (aHasRank && bHasRank) {
2546 const int64_t aRank = aType.getRank();
2547 const int64_t bRank = bType.getRank();
2548 if (aRank != bRank)
2549 return emitOpError("a and b operands don't have matching ranks, got ")
2550 << aRank << " and " << bRank;
2551
2552 // check for broadcast compatible shapes
2553 SmallVector<int64_t> resultShape;
2555 aType.getShape(), bType.getShape(), resultShape))
2556 return emitOpError("a and b operands don't have broadcast-compatible "
2557 "shapes, got ")
2558 << aType << " and " << bType;
2559 }
2560
2561 ShapedType resultType = cast<ShapedType>(output.getType());
2562 if (!resultType.hasRank())
2563 return success();
2564
2565 const int64_t resultRank = resultType.getRank();
2566 if (aHasRank && resultRank != aType.getRank())
2567 return emitOpError("result type has different rank than a, got ")
2568 << resultRank << " vs " << aType.getRank();
2569 if (bHasRank && resultRank != bType.getRank())
2570 return emitOpError("result type has different rank than b, got ")
2571 << resultRank << " vs " << bType.getRank();
2572
2573 return success();
2574}
2575
2576LogicalResult tosa::TableOp::inferReturnTypeComponents(
2577 MLIRContext *context, ::std::optional<Location> location,
2578 TableOp::Adaptor adaptor,
2579 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2580 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2581
2582 if (!inputShape.hasRank()) {
2583 inferredReturnShapes.push_back(ShapedTypeComponents());
2584 return success();
2585 }
2586
2587 inferredReturnShapes.resize(1);
2588 inputShape.getDims(inferredReturnShapes[0]);
2589 return success();
2590}
2591
2592LogicalResult tosa::TableOp::verify() {
2593 const TensorType inputType = getInput1().getType();
2594 const TensorType outputType = getOutput().getType();
2595
2596 if (!inputType.hasRank() || !outputType.hasRank())
2597 return success();
2598
2599 if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2600 "result")))
2601 return failure();
2602
2603 auto inputDims = inputType.getShape();
2604 auto outputDims = outputType.getShape();
2605 for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
2606 int64_t dim = it.index();
2607 auto [inputDim, outputDim] = it.value();
2608 if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
2609 return emitOpError() << "dim(result, " << dim << ") = " << outputDim
2610 << " doesn't match dim(input, " << dim
2611 << ") = " << inputDim;
2612 }
2613 }
2614 return success();
2615}
2616
2617LogicalResult
2618tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
2619 // Multiples must be constants.
2620 DenseIntElementsAttr multiplesAttr;
2621 if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
2622 return failure();
2623 multiples =
2624 llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
2625 [](const APInt &val) { return val.getSExtValue(); });
2626 return success();
2627}
2628
2629LogicalResult tosa::TileOp::inferReturnTypeComponents(
2630 MLIRContext *context, ::std::optional<Location> location,
2631 TileOp::Adaptor adaptor,
2632 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2633 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2634 SmallVector<int64_t> multiples;
2635 if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
2636 multiples)) {
2637 auto rank =
2638 cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
2639 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2640 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2641 return success();
2642 }
2643 multiples = convertToMlirShape(multiples);
2644
2645 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2646 SmallVector<int64_t> outputShape;
2647 if (!inputShape.hasRank()) {
2648 outputShape.resize(multiples.size(), ShapedType::kDynamic);
2649 inferredReturnShapes.push_back(
2650 ShapedTypeComponents(outputShape, inputType));
2651 return success();
2652 }
2653 if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
2654 return failure();
2655
2656 // Any non dynamic dimension can be multiplied to a known size.
2657 outputShape.reserve(multiples.size());
2658 for (int i = 0, s = inputShape.getRank(); i < s; i++) {
2659 if (multiples[i] == ShapedType::kDynamic) {
2660 outputShape.push_back(ShapedType::kDynamic);
2661 } else {
2662 int64_t dim = inputShape.getDimSize(i);
2663 if (dim != ShapedType::kDynamic)
2664 dim *= multiples[i];
2665 outputShape.push_back(dim);
2666 }
2667 }
2668
2669 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
2670 return success();
2671}
2672
2673LogicalResult tosa::TileOp::verify() {
2674 if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(),
2675 /* outType = */ getOutput().getType())
2676 .failed()) {
2677 return failure();
2678 }
2679 ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
2680 ShapedType outputType = llvm::cast<ShapedType>(getType());
2681
2682 shapeType multiplesType =
2683 llvm::cast<tosa::shapeType>(getMultiples().getType());
2684
2685 auto multiplesRank = multiplesType.getRank();
2686
2687 if (inputType.hasRank()) {
2688 if (inputType.getRank() != multiplesRank)
2689 return emitOpError("expect 'multiples' to have rank ")
2690 << inputType.getRank() << " but got " << multiplesRank << ".";
2691 if (outputType.hasRank() &&
2692 failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
2693 "output")))
2694 return failure();
2695 } else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
2696 return emitOpError("expect 'multiples' array to have length ")
2697 << outputType.getRank() << " but got " << multiplesRank << ".";
2698
2699 SmallVector<int64_t> multiples;
2700 if (getConstantMultiples(multiples).succeeded() &&
2701 llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; }))
2702 return emitOpError(
2703 "expect element of 'multiples' to be positive integer or -1.");
2704
2705 return success();
2706}
2707
2708bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
2709 if (l.size() != r.size() || l.size() != 1)
2710 return false;
2711 return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
2712}
2713
2714LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
2715 MLIRContext *context, ::std::optional<Location> location,
2716 ReshapeOp::Adaptor adaptor,
2717 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
2718 ShapeAdaptor inputShape(adaptor.getInput1().getType());
2719 Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
2720 llvm::SmallVector<int64_t> newShapeValue;
2721 if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(),
2722 newShapeValue)) {
2723 auto rank = cast<tosa::shapeType>(adaptor.getShape().getType()).getRank();
2724 SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
2725 inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
2726 return success();
2727 }
2728 newShapeValue = convertToMlirShape(newShapeValue);
2729
2730 // We cannot infer from the total number of elements so we must take the
2731 // shape attribute as exact.
2732 if (!inputShape.hasRank() || !inputShape.hasStaticShape()) {
2733 inferredReturnShapes.push_back(
2734 ShapedTypeComponents(newShapeValue, inputType));
2735 return success();
2736 }
2737
2738 // Determine the number of elements covered by the slice of all static
2739 // dimensions. This allows us to infer the length of the remaining dynamic
2740 // dimension.
2741 int64_t numElements = inputShape.getNumElements();
2742 int64_t staticMul = 1;
2743 for (auto val : newShapeValue) {
2744 if (ShapedType::isStatic(val)) {
2745 staticMul *= val;
2746 }
2747 }
2748
2749 // Determine the length of the dynamic dimension.
2750 for (auto &val : newShapeValue) {
2751 if (ShapedType::isDynamic(val))
2752 val = numElements / staticMul;
2753 }
2754
2755 inferredReturnShapes.push_back(
2756 ShapedTypeComponents(newShapeValue, inputType));
2757 return success();
2758}
2759
2760llvm::LogicalResult tosa::ReshapeOp::verify() {
2761 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
2762 /* outType = */ getOutput().getType())
2763 .failed()) {
2764 return failure();
2765 }
2766 TensorType inputType = getInput1().getType();
2767
2768 SmallVector<int64_t> shapeValues;
2769 if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) {
2770 // skip following checks if shape is not constant
2771 return mlir::success();
2772 }
2773
2774 int missingDims = llvm::count(shapeValues, kInferableDimSize);
2775 if (missingDims > 1)
2776 return emitOpError() << "expected at most one target dimension to be "
2778
2779 const auto outputType = dyn_cast<RankedTensorType>(getType());
2780 if (!outputType)
2781 return success();
2782
2783 if ((int64_t)shapeValues.size() != outputType.getRank())
2784 return emitOpError() << "new shape does not match result rank";
2785
2786 for (auto [newShapeDim, outputShapeDim] :
2787 zip(shapeValues, outputType.getShape())) {
2788 if (newShapeDim != kInferableDimSize &&
2789 newShapeDim != ShapedType::kDynamic &&
2790 outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim)
2791 return emitOpError() << "new shape is inconsistent with result shape";
2792
2793 if (newShapeDim != ShapedType::kDynamic && newShapeDim < kInferableDimSize)
2794 return emitOpError() << "new shape has invalid tensor dimension size "
2795 << newShapeDim;
2796 }
2797
2798 if (inputType.hasStaticShape()) {
2799 int64_t inputElementsNum = inputType.getNumElements();
2800 if (outputType.hasStaticShape()) {
2801 int64_t outputElementsNum = outputType.getNumElements();
2802 if (inputElementsNum != outputElementsNum) {
2803 return emitOpError() << "cannot reshape " << inputElementsNum
2804 << " elements into " << outputElementsNum;
2805 }
2806 }
2807
2808 int64_t newShapeElementsNum =
2809 llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
2810 return (dim > 0) ? acc * dim : acc;
2811 });
2812 bool isStaticNewShape =
2813 llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
2814 if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
2815 (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
2816 return emitOpError() << "cannot reshape " << inputElementsNum
2817 << " elements into " << newShapeElementsNum;
2818 }
2819 }
2820
2821 return mlir::success();
2822}
2823
2824bool tosa::ReshapeBlockScaledOp::isCompatibleReturnTypes(TypeRange l,
2825 TypeRange r) {
2826 if (l.size() != r.size() || l.size() < 1 || l.size() > 2)
2827 return false;
2828 bool ok = (getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]));
2829 if (l.size() == 2)
2830 ok = ok && (getElementTypeOrSelf(l[1]) == getElementTypeOrSelf(r[1]));
2831 return ok;
2832}
2833
LogicalResult tosa::ReshapeBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ReshapeBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {

  // One result per input: input[0] is the data tensor and the optional
  // input[1] is the per-block scale tensor.
  const auto numInputs = adaptor.getInput().size();
  ShapeAdaptor inputShape(adaptor.getInput()[0].getType());
  Type inputType = getElementTypeOrSelf(adaptor.getInput()[0].getType());
  llvm::SmallVector<int64_t> newShapeValue;
  const auto newShape = adaptor.getNewValueShape();
  if (!tosa::getConstShapeValues(newShape.getDefiningOp(), newShapeValue)) {
    // Non-constant new shape: only the rank (from the shape operand's type)
    // is known; every dimension of both results is reported dynamic.
    auto rank = cast<tosa::shapeType>(newShape.getType()).getRank();
    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
    if (numInputs == 2)
      inferredReturnShapes.push_back(ShapedTypeComponents(
          fallback, getElementTypeOrSelf(adaptor.getInput()[1].getType())));
    return success();
  }

  const uint32_t blockSize =
      BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());

  // The scale result copies the data shape with its last (static) dimension
  // divided by the block size.
  llvm::SmallVector<int64_t> newScaleShapeValue;
  if (numInputs == 2) {
    newScaleShapeValue.assign(newShapeValue.begin(), newShapeValue.end());
    if (ShapedType::isStatic(newScaleShapeValue.back()))
      newScaleShapeValue.back() /= blockSize;
  }

  inferredReturnShapes.push_back(
      ShapedTypeComponents(newShapeValue, inputType));
  if (numInputs == 2) {
    // Fix up scale shape - with special case for last dimension
    // NOTE(review): when the last entry is dynamic, the division below is
    // applied to the dynamic marker itself (kDynamic / blockSize), which
    // does not preserve the marker value — confirm this path is intended
    // (newShapeValue is not run through convertToMlirShape here).
    for (size_t idx = 0; idx < newShapeValue.size(); idx++) {
      if (ShapedType::isDynamic(newScaleShapeValue[idx])) {
        newScaleShapeValue[idx] = newShapeValue[idx];
        if (idx == (newShapeValue.size() - 1))
          newScaleShapeValue[idx] /= blockSize;
      }
    }

    inferredReturnShapes.push_back(ShapedTypeComponents(
        newScaleShapeValue,
        getElementTypeOrSelf(adaptor.getInput()[1].getType())));
  }
  return success();
}
2882
// Verifies RESHAPE_BLOCK_SCALED: 1-2 inputs (data, optional scale) with a
// matching number of results, block-size constraints, data/scale shape
// agreement, consistency of the constant new shape with the result types,
// and element-count preservation.
llvm::LogicalResult tosa::ReshapeBlockScaledOp::verify() {
  const Operation::operand_range inputList = getInput();
  const Operation::result_range outputList = getResults();

  if (inputList.size() == 0)
    return emitOpError("requires at least one input");

  if (inputList.size() > 2)
    return emitOpError("requires at most two inputs");

  if (inputList.size() != outputList.size())
    return emitOpError("requires number of results to match inputs");

  // Data input and data result must share an element type.
  if (verifySameElementTypes(*this, /* inType = */ inputList[0].getType(),
                             /* outType = */ outputList[0].getType())
          .failed()) {
    return failure();
  }

  const auto inputType = llvm::cast<ShapedType>(inputList[0].getType());
  if (!inputType.hasRank())
    return success();
  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());

  if (inputList.size() == 2) {
    // With a scale operand the block size must be 32 and all ranked
    // inputs/results must be non-scalar.
    if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
      return emitOpError("expect block size to be 32, got ") << blockSize;
    if (llvm::any_of(inputList, [](Value v) {
          const auto input = cast<ShapedType>(v.getType());
          return input.hasRank() && input.getRank() == 0;
        }))
      return emitOpError(
          "requires all input shapes have a rank greater than 0");
    if (llvm::any_of(outputList, [](Value v) {
          const auto output = cast<ShapedType>(v.getType());
          return output.hasRank() && output.getRank() == 0;
        }))
      return emitOpError(
          "requires all result shapes have a rank greater than 0");

    // Scale input and scale result must share an element type.
    if (verifySameElementTypes(*this, /* inType = */ inputList[1].getType(),
                               /* outType = */ outputList[1].getType())
            .failed()) {
      return failure();
    }

    const auto inputScaleType = llvm::cast<ShapedType>(inputList[1].getType());
    if (inputScaleType.hasRank()) {
      if (inputType.getRank() != inputScaleType.getRank())
        return emitOpError("input shapes do not have same rank");

      // Check all but the last dimension that the input shape dimensions match
      for (auto dimIdx = 0; dimIdx < inputType.getRank() - 1; dimIdx++) {
        const int64_t inputValueDim = inputType.getDimSize(dimIdx);
        const int64_t inputScaleDim = inputScaleType.getShape()[dimIdx];
        if (ShapedType::isStatic(inputValueDim) &&
            ShapedType::isStatic(inputScaleDim) &&
            inputValueDim != inputScaleDim)
          return emitOpError("input shapes for data and scale do not match on "
                             "dimension ")
                 << dimIdx;
      }

      // Verify last dimension of input is a multiple of block size
      const int64_t lastValueDim =
          inputType.getDimSize(inputType.getRank() - 1);
      if (ShapedType::isStatic(lastValueDim)) {
        if (lastValueDim % blockSize != 0)
          return emitOpError("expect last dimension of input_data (")
                 << lastValueDim << ") to be divisible by block_size ("
                 << blockSize << ")";

        const int64_t lastScaleDim =
            inputScaleType.getDimSize(inputScaleType.getRank() - 1);
        // Verify last dimension of scale is lastValueDim / block size
        if (ShapedType::isStatic(lastScaleDim) &&
            lastScaleDim != lastValueDim / blockSize)
          return emitOpError("expect last dimension of scale_data (")
                 << lastScaleDim << ") to be " << lastValueDim << "/"
                 << blockSize;
      }
    }
  } else {
    // Single-input form requires the trivial block size.
    if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_1))
      return emitOpError("expect block size to be 1, got ") << blockSize;
  }

  // Get the new value shape dimension values
  SmallVector<int64_t> shapeValues;
  if (!tosa::getConstShapeValues(getNewValueShape().getDefiningOp(),
                                 shapeValues)) {
    // skip following checks if shape is not constant
    return mlir::success();
  }

  if (inputList.size() == 2) {
    // The (non-scalar) new shape's last dimension must tile into blocks.
    if (static_cast<int64_t>(shapeValues.size()) == 0)
      return emitOpError("requires new shape to have a rank greater than 0");

    const int64_t lastShapeDim = shapeValues.back();
    if (ShapedType::isStatic(lastShapeDim) && lastShapeDim % blockSize != 0)
      return emitOpError("expect last dimension of new shape (")
             << lastShapeDim << ") to be divisible by block_size (" << blockSize
             << ")";
  }

  const auto outputType = llvm::cast<ShapedType>(outputList[0].getType());
  if (!outputType.hasRank())
    return success();

  if (static_cast<int64_t>(shapeValues.size()) != outputType.getRank())
    return emitOpError() << "result does not match new shape rank";

  // Static entries of the new shape must match static result dimensions.
  for (auto [newShapeDim, outputShapeDim] :
       zip(shapeValues, outputType.getShape())) {
    if (ShapedType::isStatic(newShapeDim) &&
        ShapedType::isStatic(outputShapeDim) && newShapeDim != outputShapeDim)
      return emitOpError() << "result shape is inconsistent with new shape";
  }

  if (outputList.size() == 2) {
    // Set up scale shape from new shape given
    SmallVector<int64_t> scaleShapeValues(shapeValues.begin(),
                                          shapeValues.end());
    scaleShapeValues.back() /= blockSize;

    const auto outputScaleType =
        llvm::cast<ShapedType>(outputList[1].getType());
    if (outputScaleType.hasRank()) {
      if ((int64_t)scaleShapeValues.size() != outputScaleType.getRank())
        return emitOpError() << "result scale does not match new shape rank";

      // Static entries of the derived scale shape must match the scale
      // result's static dimensions.
      for (auto [newScaleShapeDim, outputScaleShapeDim] :
           zip(scaleShapeValues, outputScaleType.getShape())) {
        if (ShapedType::isStatic(newScaleShapeDim) &&
            ShapedType::isStatic(outputScaleShapeDim) &&
            newScaleShapeDim != outputScaleShapeDim)
          return emitOpError()
                 << "result scale shape is inconsistent with new shape";
      }
    }
  }

  if (inputType.hasStaticShape()) {
    int64_t inputElementsNum = inputType.getNumElements();
    // A fully static result must preserve the element count exactly.
    if (outputType.hasStaticShape()) {
      int64_t outputElementsNum = outputType.getNumElements();
      if (inputElementsNum != outputElementsNum) {
        return emitOpError() << "cannot reshape " << inputElementsNum
                             << " elements into " << outputElementsNum;
      }
    }

    // Product of the positive (static) entries of the new shape: a fully
    // static new shape must match the element count exactly; a partially
    // static one must not already exceed it.
    int64_t newShapeElementsNum =
        llvm::accumulate(shapeValues, int64_t(1), [](int64_t acc, int64_t dim) {
          return (dim > 0) ? acc * dim : acc;
        });
    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;
    }
  }

  return mlir::success();
}
3051
3052// return failure if val is not a constant
3053// set zp to -1 if val is non-zero float or val is not integer nor float
3054// otherwise set zp to val's constant value
3055static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
3056 ElementsAttr zpAttr;
3057 if (!matchPattern(val, m_Constant(&zpAttr))) {
3058 return failure();
3059 }
3060
3061 Type zpElemType = zpAttr.getElementType();
3062
3063 if (llvm::isa<FloatType>(zpElemType)) {
3064 if (zpAttr.getValues<APFloat>()[0].isZero()) {
3065 return 0;
3066 }
3067 // return non-zero value to trigger error check
3068 return -1;
3069 }
3070
3071 if (llvm::isa<IntegerType>(zpElemType)) {
3072 if (signExtend)
3073 return zpAttr.getValues<APInt>()[0].getSExtValue();
3074 return zpAttr.getValues<APInt>()[0].getZExtValue();
3075 }
3076
3077 // return non-zero value to trigger error check
3078 return -1;
3079}
3080
3081static FailureOr<int64_t> getConstantScalarIntValue(Value val) {
3082 ElementsAttr attr;
3083 if (!matchPattern(val, m_Constant(&attr)))
3084 return failure();
3085
3086 if (!llvm::isa<IntegerType>(attr.getElementType()) ||
3087 attr.getNumElements() != 1)
3088 return failure();
3089
3090 return attr.getValues<APInt>()[0].getSExtValue();
3091}
3092
3093template <typename T>
3094static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
3095 const std::string &operand) {
3096 Type zpElemType = getElementTypeOrSelf(val);
3097
3098 if (!zpElemType.isInteger(8) && zp != 0) {
3099 // convert operand to lower case for error message
3100 std::string lower = operand;
3101 llvm::transform(lower, lower.begin(), ::tolower);
3102 return op.emitOpError()
3103 << lower << " zero point must be zero for non-int8 integer types";
3104 }
3105
3106 return success();
3107}
3108
3109static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
3110 const int64_t &zp,
3111 const std::string &operand) {
3112 bool isInputZp = (operand == "Input");
3113
3114 bool tensorUnsigned =
3115 isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
3116 StringRef tensorName = isInputZp ? "input" : "output";
3117
3118 Type zpElemType = getElementTypeOrSelf(zpVal);
3119
3120 if (zp != 0) {
3121 if (!zpElemType.isInteger(8) &&
3122 !(zpElemType.isInteger(16) && tensorUnsigned)) {
3123 return op.emitOpError()
3124 << "expect " << tensorName << "_zp of 0, got " << zp;
3125 }
3126 if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) {
3127 return op.emitOpError() << "expect " << tensorName
3128 << "_zp of 0 or 32768 for unsigned int16 "
3129 << tensorName << ", got " << zp;
3130 }
3131 }
3132
3133 return success();
3134}
3135
// Emits, for tosa::OP, the pair of zero-point accessors:
//   get<OPERAND_NAME>ZeroPoint()      - folds the zp operand to a constant
//                                       via getZeroPoint(); SIGN_EXTEND
//                                       selects sign- vs zero-extension.
//   verify<OPERAND_NAME>ZeroPoint(zp) - dispatches to the matching
//                                       verifyZeroPoint overload above,
//                                       passing #OPERAND_NAME as the
//                                       diagnostic operand name.
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)                       \
  FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() {                \
    return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND);                 \
  }                                                                            \
  LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) {        \
    return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \
  }

ZERO_POINT_HELPER(Conv2DOp, Input, true)
ZERO_POINT_HELPER(Conv2DOp, Weight, true)
ZERO_POINT_HELPER(Conv3DOp, Input, true)
ZERO_POINT_HELPER(Conv3DOp, Weight, true)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true)
ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true)
ZERO_POINT_HELPER(TransposeConv2DOp, Input, true)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true)
ZERO_POINT_HELPER(AvgPool2dOp, Input, true)
ZERO_POINT_HELPER(AvgPool2dOp, Output, true)
ZERO_POINT_HELPER(AvgPool2dAdaptiveOp, Input, true)
ZERO_POINT_HELPER(AvgPool2dAdaptiveOp, Output, true)
ZERO_POINT_HELPER(MatMulOp, A, true)
ZERO_POINT_HELPER(MatMulOp, B, true)
ZERO_POINT_HELPER(NegateOp, Input1, true)
ZERO_POINT_HELPER(NegateOp, Output, true)
// Rescale sign-extends only when the corresponding tensor is signed.
ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned())
ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned())
#undef ZERO_POINT_HELPER
3163
3164LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
3165 MLIRContext *context, ::std::optional<Location> location,
3166 TransposeOp::Adaptor adaptor,
3167 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3168 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3169
3170 // If input rank and permutation length is unknown, the output rank is
3171 // unknown.
3172 if (!inputShape.hasRank()) {
3173 inferredReturnShapes.push_back(ShapedTypeComponents());
3174 return success();
3175 }
3176
3177 const auto inputRank = inputShape.getRank();
3178
3179 // This would imply the number of permutations does not match the rank of
3180 // the input which is illegal.
3181 if (adaptor.getPerms().size() != static_cast<size_t>(inputRank)) {
3182 return failure();
3183 }
3184
3185 SmallVector<int64_t> outputShape;
3186 // Rank-0 means no permutations matter.
3187 if (inputRank == 0) {
3188 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3189 return success();
3190 }
3191
3192 // Check whether the input dimensions are all the same.
3193 bool allTheSame = true;
3194 for (int i = 1, s = inputRank; i < s; i++) {
3195 if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) {
3196 allTheSame = false;
3197 break;
3198 }
3199 }
3200
3201 // If all of the input dimensions are the same we don't care about the
3202 // permutation.
3203 if (allTheSame) {
3204 outputShape.resize(inputRank, inputShape.getDimSize(0));
3205 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3206 return success();
3207 }
3208
3209 outputShape.resize(inputRank, ShapedType::kDynamic);
3210
3211 // Constant permutation values must be within the input rank.
3212 if (llvm::any_of(adaptor.getPerms(),
3213 [inputRank](const auto i) { return i >= inputRank; }))
3214 return failure();
3215
3216 outputShape.reserve(inputRank);
3217 for (int i = 0, s = inputRank; i < s; i++) {
3218 outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]);
3219 }
3220
3221 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3222 return success();
3223}
3224
// Verifies TRANSPOSE: element types match, the perms attribute is a valid
// permutation sized to the input/output ranks, element counts agree, and
// each static output dimension equals the permuted input dimension.
LogicalResult tosa::TransposeOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor inputShape(getInput1().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  const llvm::ArrayRef<int32_t> constantPerms = getPerms();

  // perms must carry one entry per input dimension ...
  if (inputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(inputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << inputShape.getRank()
                         << " (input rank) but got size "
                         << constantPerms.size();

  if (inputShape.hasRank() && outputShape.hasRank() &&
      inputShape.getRank() != outputShape.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  // ... and per output dimension.
  if (outputShape.hasRank() &&
      constantPerms.size() != static_cast<size_t>(outputShape.getRank()))
    return emitOpError() << "expected perms attribute to have size "
                         << outputShape.getRank()
                         << " (output rank) but got size "
                         << constantPerms.size();

  // Every entry must lie in [0, size) and the entries together must form a
  // permutation.
  if (!llvm::all_of(constantPerms,
                    [&constantPerms](int32_t s) {
                      return s >= 0 &&
                             static_cast<size_t>(s) < constantPerms.size();
                    }) ||
      !isPermutationVector(llvm::map_to_vector(
          constantPerms, [](int32_t v) -> int64_t { return v; })))
    return emitOpError() << "expected valid permutation indices";

  // ERROR_IF(tensor_size(shape1) != tensor_size(shape))
  if (inputShape.hasStaticShape() && outputShape.hasStaticShape() &&
      inputShape.getNumElements() != outputShape.getNumElements())
    return emitOpError() << "expected input1 and output to have same numbers "
                            "of elements, got "
                         << inputShape.getNumElements() << " and "
                         << outputShape.getNumElements();

  // Verify that the types of the input and output tensors are properly
  // permuted.
  if (inputShape.hasRank() && outputShape.hasRank()) {
    for (auto i = 0; i < outputShape.getRank(); i++) {
      // Dynamic dimensions on either side are unconstrained.
      if (inputShape.isDynamicDim(constantPerms[i]) ||
          outputShape.isDynamicDim(i))
        continue;

      if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i))
        return emitOpError()
               << "expected output tensor dim " << i << " to match "
               << "input dim " << constantPerms[i] << " with value of "
               << inputShape.getDimSize(constantPerms[i]);
    }
  }

  return success();
}
3291
3292LogicalResult TransposeOp::reifyResultShapes(
3293 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
3294
3295 const llvm::ArrayRef<int32_t> transposePerms = getPerms();
3296
3297 Value input = getInput1();
3298 auto inputType = cast<TensorType>(input.getType());
3299
3300 SmallVector<OpFoldResult> returnedDims(inputType.getRank());
3301 for (auto dim : transposePerms) {
3302 int32_t dimInInput = transposePerms[dim];
3303 if (inputType.isDynamicDim(dimInInput))
3304 returnedDims[dim] =
3305 tensor::DimOp::create(builder, getLoc(), input, dimInInput)
3306 .getResult();
3307 else
3308 returnedDims[dim] =
3309 builder.getIndexAttr(inputType.getDimSize(dimInInput));
3310 }
3311
3312 reifiedReturnShapes.emplace_back(std::move(returnedDims));
3313 return success();
3314}
3315
3316LogicalResult tosa::GatherOp::inferReturnTypeComponents(
3317 MLIRContext *context, ::std::optional<Location> location,
3318 GatherOp::Adaptor adaptor,
3319 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3320 llvm::SmallVector<int64_t> outputShape;
3321 outputShape.resize(3, ShapedType::kDynamic);
3322
3323 ShapeAdaptor valuesShape(adaptor.getValues().getType());
3324 if (valuesShape.hasRank()) {
3325 outputShape[0] = valuesShape.getDimSize(0);
3326 outputShape[2] = valuesShape.getDimSize(2);
3327 }
3328
3329 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3330 if (indicesShape.hasRank()) {
3331 if (outputShape[0] == ShapedType::kDynamic)
3332 outputShape[0] = indicesShape.getDimSize(0);
3333 if (outputShape[1] == ShapedType::kDynamic)
3334 outputShape[1] = indicesShape.getDimSize(1);
3335 }
3336
3337 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3338 return success();
3339}
3340
LogicalResult tosa::RowGatherBlockScaledOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RowGatherBlockScaledOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  const auto values = adaptor.getValues();
  if (values.empty())
    return failure();

  // Data result: dims 0 and 2 come from values[0]; dim 0 falls back to the
  // indices batch dim, and dim 1 is indices width times row_count when both
  // are known.
  SmallVector<int64_t> dataShape(3, ShapedType::kDynamic);
  const ShapeAdaptor valuesShape(values.front().getType());
  if (valuesShape.hasRank()) {
    dataShape[0] = valuesShape.getDimSize(0);
    dataShape[2] = valuesShape.getDimSize(2);
  }

  const ShapeAdaptor indicesShape(adaptor.getIndices().getType());
  if (indicesShape.hasRank()) {
    if (dataShape[0] == ShapedType::kDynamic)
      dataShape[0] = indicesShape.getDimSize(0);

    // dim 1 is only inferable when row_count folds to a positive constant
    // and the indices width is static.
    if (auto rowCount = getConstantScalarIntValue(adaptor.getRowCount());
        succeeded(rowCount) && rowCount.value() > 0) {
      const int64_t indicesW = indicesShape.getDimSize(1);
      if (ShapedType::isStatic(indicesW))
        dataShape[1] = indicesW * rowCount.value();
    }
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(dataShape));
  if (values.size() == 1)
    return success();

  // Optional scale result mirrors the data shape with the last dimension
  // divided by the block size (when static).
  SmallVector<int64_t> scaleShape = dataShape;
  const uint32_t blockSize =
      BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
  if (ShapedType::isStatic(dataShape[2]))
    scaleShape[2] = dataShape[2] / blockSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(scaleShape));
  return success();
}
3382
// Verifies GATHER: element types match and the N/W/C dimensions agree across
// values, indices and output wherever they are statically known.
LogicalResult tosa::GatherOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor valuesShape(getValues().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor outputShape(getOutput().getType());

  // Accumulate the best-known N/W/C sizes; kDynamic means "not known yet".
  int64_t n = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;

  if (valuesShape.hasRank()) {
    n = valuesShape.getDimSize(0);
    c = valuesShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (outputShape.hasRank()) {
    // Each static output dimension must equal the corresponding known size.
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputW = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires output dimension 0 to have size " << n
                           << ", got " << outputN;

    if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
        w != outputW)
      return emitOpError() << "requires output dimension 1 to have size " << w
                           << ", got " << outputW;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires output dimension 2 to have size " << c
                           << ", got " << outputC;
  }
  return success();
}
3431
// Verifies tosa.row_gather_block_scaled:
//  - values/output are tensor lists of matching length (1 = data only,
//    2 = data + per-block scales), with block_size consistent with that
//    length and element types matching pairwise.
//  - values[0] is [N, K, C], indices is [N, W]; output[0] rows must equal
//    W * row_count when row_count is a constant > 0.
//  - Scale tensors carry C/block_size channels.
LogicalResult tosa::RowGatherBlockScaledOp::verify() {
  const OperandRange values = getValues();
  const ResultRange output = getOutput();
  if (values.empty() || values.size() > 2)
    return emitOpError()
           << "expects values tensor list length to be 1 or 2, got "
           << values.size();
  if (output.size() != values.size())
    return emitOpError()
           << "expects output tensor list length to match values tensor list "
              "length, got "
           << output.size() << " results for " << values.size()
           << " input tensors";

  // A scale tensor is only meaningful with a real block size, and vice versa.
  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (values.size() == 1 && blockSize != 1)
    return emitOpError()
           << "requires block_size to be BLOCK_SIZE_1 when values tensor list "
              "length is 1";
  if (values.size() == 2 && blockSize == 1)
    return emitOpError()
           << "requires block_size to not be BLOCK_SIZE_1 when values tensor "
              "list length is 2";

  if (failed(verifySameElementTypes(*this, values[0].getType(),
                                    output[0].getType(), "values[0]",
                                    "output[0]")))
    return failure();
  if (values.size() == 2 && failed(verifySameElementTypes(
                                *this, values[1].getType(), output[1].getType(),
                                "values[1]", "output[1]")))
    return failure();

  // row_count is an operand; only check it when it folds to a constant.
  if (auto rowCount = getConstantScalarIntValue(getRowCount());
      succeeded(rowCount) && rowCount.value() <= 0)
    return emitOpError() << "requires row_count to be > 0, got "
                         << rowCount.value();

  // N/K/C/W extents accumulated across operands; kDynamic = unknown so far.
  int64_t n = ShapedType::kDynamic;
  int64_t k = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t multiplesOfC = ShapedType::kDynamic;

  const ShapeAdaptor valuesDataShape(values[0].getType());
  if (valuesDataShape.hasRank()) {
    n = valuesDataShape.getDimSize(0);
    k = valuesDataShape.getDimSize(1);
    c = valuesDataShape.getDimSize(2);
  }

  if (ShapedType::isStatic(c) && c % blockSize != 0)
    return emitOpError() << "expects channels of values[0] (" << c
                         << ") to be divisible by block_size (" << blockSize
                         << ")";

  const ShapeAdaptor indicesShape(getIndices().getType());
  if (indicesShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, n, indicesShape.getDimSize(0),
                                     "indices", "batch")))
      return failure();
    w = indicesShape.getDimSize(1);
  }

  const ShapeAdaptor outputDataShape(output[0].getType());
  if (outputDataShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, n, outputDataShape.getDimSize(0),
                                     "output[0]", "batch")) ||
        failed(tryUpdateDimOrFailure(*this, c, outputDataShape.getDimSize(2),
                                     "output[0]", "channels")))
      return failure();

    // output[0] gathers row_count rows per index: dim 1 == W * row_count.
    if (auto rowCount = getConstantScalarIntValue(getRowCount());
        succeeded(rowCount) && rowCount.value() > 0 &&
        ShapedType::isStatic(w)) {
      const int64_t expectedOutputRows = w * rowCount.value();
      if (ShapedType::isStatic(outputDataShape.getDimSize(1)) &&
          outputDataShape.getDimSize(1) != expectedOutputRows)
        return emitOpError() << "requires output[0] dimension 1 to have size "
                             << expectedOutputRows << ", got "
                             << outputDataShape.getDimSize(1);
    }
  }

  // Checks that only apply when scale tensors are present.
  if (values.size() == 2) {
    const ShapeAdaptor valuesScaleShape(values[1].getType());
    if (valuesScaleShape.hasRank()) {
      if (failed(tryUpdateDimOrFailure(*this, n, valuesScaleShape.getDimSize(0),
                                       "values[1]", "batch")) ||
          failed(tryUpdateDimOrFailure(*this, k, valuesScaleShape.getDimSize(1),
                                       "values[1]", "rows")))
        return failure();
      multiplesOfC = valuesScaleShape.getDimSize(2);
    }

    const ShapeAdaptor outputScaleShape(output[1].getType());
    if (outputScaleShape.hasRank()) {
      if (failed(tryUpdateDimOrFailure(*this, n, outputScaleShape.getDimSize(0),
                                       "output[1]", "batch")))
        return failure();

      // output[1] rows follow the same W * row_count rule as output[0].
      if (auto rowCount = getConstantScalarIntValue(getRowCount());
          succeeded(rowCount) && rowCount.value() > 0 &&
          ShapedType::isStatic(w)) {
        const int64_t expectedOutputRows = w * rowCount.value();
        if (ShapedType::isStatic(outputScaleShape.getDimSize(1)) &&
            outputScaleShape.getDimSize(1) != expectedOutputRows)
          return emitOpError() << "requires output[1] dimension 1 to have size "
                               << expectedOutputRows << ", got "
                               << outputScaleShape.getDimSize(1);
      }

      if (ShapedType::isDynamic(multiplesOfC))
        multiplesOfC = outputScaleShape.getDimSize(2);
      else if (ShapedType::isStatic(outputScaleShape.getDimSize(2)) &&
               multiplesOfC != outputScaleShape.getDimSize(2))
        return emitOpError()
               << "expected channels of output[1] to match size "
               << multiplesOfC << ", got " << outputScaleShape.getDimSize(2);
    }

    // One scale value per block of channels.
    if (ShapedType::isStatic(c) && ShapedType::isStatic(multiplesOfC) &&
        multiplesOfC != c / blockSize)
      return emitOpError()
             << "expects channels of scale tensors to equal C/block_size (" << c
             << "/" << blockSize << "), got " << multiplesOfC;
  }

  return success();
}
3562
// Infers the NHWC output shape of tosa.resize. Requires a ranked input with
// static H/W and constant scale/offset/border shape operands; otherwise the
// output shape cannot be computed and inference fails.
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ResizeOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t, 4> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  ShapeAdaptor inputShape(adaptor.getInput().getType());
  if (!inputShape.hasRank())
    return failure();

  // Batch and channel dimensions pass through unchanged.
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[3] = inputShape.getDimSize(3);
  int64_t inputHeight = inputShape.getDimSize(1);
  int64_t inputWidth = inputShape.getDimSize(2);

  if ((inputHeight == ShapedType::kDynamic) ||
      (inputWidth == ShapedType::kDynamic))
    return failure();

  // scale is [y_n, y_d, x_n, x_d]; offset/border are [y, x] (see verify()).
  SmallVector<int64_t> scaleInt, offsetInt, borderInt;
  if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(),
                                 scaleInt) ||
      !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(),
                                 offsetInt) ||
      !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(),
                                 borderInt)) {
    return failure();
  }

  // Compute the output shape based on attributes: scale, offset, and border.
  const int64_t outputHeight =
      (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) /
       scaleInt[1]) +
      1;

  const int64_t outputWidth =
      (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) /
       scaleInt[3]) +
      1;

  if (outputHeight < 0 || outputWidth < 0) {
    return emitOptionalError(
        location,
        "calculated output height and width must be non-negative, "
        "got height = ",
        outputHeight, ", width = ", outputWidth);
  }

  outputShape[1] = outputHeight;
  outputShape[2] = outputWidth;
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}
3617
// Verifies tosa.resize: scale values must be positive, and when the input
// H/W and the scale/offset/border operands are all static/constant, the
// declared output H/W must match the computed resize arithmetic. Checks are
// skipped wherever the needed values are not statically known.
LogicalResult tosa::ResizeOp::verify() {
  const Value input = getInput();
  const Value output = getOutput();
  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(input.getType());
  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(output.getType());

  SmallVector<int64_t> scaleValues;
  SmallVector<int64_t> offsetValues;
  SmallVector<int64_t> borderValues;
  if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) ||
      !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) ||
      !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) {
    // Skip following checks if shape is not constant
    return success();
  }

  if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
    return emitOpError("expect all scale values to be > 0, got ")
           << scaleValues;

  // scale is [y_n, y_d, x_n, x_d] (numerator/denominator per axis).
  const int64_t scaleYN = scaleValues[0];
  const int64_t scaleYD = scaleValues[1];
  const int64_t scaleXN = scaleValues[2];
  const int64_t scaleXD = scaleValues[3];

  const int64_t offsetY = offsetValues[0];
  const int64_t offsetX = offsetValues[1];

  const int64_t borderY = borderValues[0];
  const int64_t borderX = borderValues[1];

  // Unranked input or output: nothing further to check.
  if (!inputType)
    return success();
  if (!outputType)
    return success();

  const int64_t oh = outputType.getDimSize(1);
  const int64_t ow = outputType.getDimSize(2);
  const int64_t ih = inputType.getDimSize(1);
  const int64_t iw = inputType.getDimSize(2);

  // Don't check with input height that could be broadcast (ih != 1)
  // since Linalg, a consumer of TOSA, expects broadcasting support
  // in resize to be available. Taking the cautious approach for now,
  // we can consider removing support for broadcasting later.
  if (ih != ShapedType::kDynamic && ih != 1) {
    // idivCheck returns nullopt when the division is not exact.
    const std::optional<int64_t> calculatedOutHeightMinusOne =
        idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
    if (!calculatedOutHeightMinusOne.has_value())
      return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
                         "border_y ")
             << "to be wholly divisible by scale_y_d, got ((" << ih
             << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
             << ") / " << scaleYD;
    const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
    if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
      return emitOpError("calculated output height did not match expected: ")
             << "calculated=" << calculatedOutHeight << ", expected=" << oh;
  }

  // Don't check with input width that could be broadcast (iw != 1)
  // since Linalg, a consumer of TOSA, expects broadcasting support
  // in resize to be available. Taking the cautious approach for now,
  // we can consider removing support for broadcasting later.
  if (iw != ShapedType::kDynamic && iw != 1) {
    const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
    const std::optional<int64_t> calculatedOutWidthMinusOne =
        idivCheck(scaledInWidth, scaleXD);
    if (!calculatedOutWidthMinusOne.has_value())
      return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
                         "border_x ")
             << "to be wholly divisible by scale_x_d, got ((" << iw
             << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
             << ") / " << scaleXD;
    const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
    if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
      return emitOpError("calculated output width did not match expected: ")
             << "calculated=" << calculatedOutWidth << ", expected=" << ow;
  }

  return success();
}
3702
3703LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
3704 MLIRContext *context, ::std::optional<Location> location,
3705 ScatterOp::Adaptor adaptor,
3706 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3707 llvm::SmallVector<int64_t> outputShape;
3708 outputShape.resize(3, ShapedType::kDynamic);
3709
3710 ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType());
3711 if (valuesInShape.hasRank()) {
3712 outputShape[0] = valuesInShape.getDimSize(0);
3713 outputShape[1] = valuesInShape.getDimSize(1);
3714 outputShape[2] = valuesInShape.getDimSize(2);
3715 }
3716
3717 ShapeAdaptor indicesShape(adaptor.getIndices().getType());
3718 if (indicesShape.hasRank()) {
3719 if (outputShape[0] == ShapedType::kDynamic)
3720 outputShape[0] = indicesShape.getDimSize(0);
3721 }
3722
3723 ShapeAdaptor inputShape(adaptor.getInput().getType());
3724 if (inputShape.hasRank()) {
3725 if (outputShape[0] == ShapedType::kDynamic)
3726 outputShape[0] = inputShape.getDimSize(0);
3727 if (outputShape[2] == ShapedType::kDynamic)
3728 outputShape[2] = inputShape.getDimSize(2);
3729 }
3730
3731 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
3732 return success();
3733}
3734
// Verifies tosa.scatter: element types of values_in/input/values_out must
// agree, and the shapes values_in [N, K, C], indices [N, W], input [N, W, C]
// and values_out [N, K, C] must be consistent, with K >= W. Dynamic
// dimensions are skipped; each extent is pinned by the first operand that
// provides it statically.
LogicalResult tosa::ScatterOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(),
                             /* outType = */ getValuesOut().getType())
          .failed() ||
      verifySameElementTypes(*this, /* inType = */ getInput().getType(),
                             /* outType = */ getValuesOut().getType())
          .failed()) {
    return failure();
  }

  const ShapeAdaptor valuesInShape(getValuesIn().getType());
  const ShapeAdaptor indicesShape(getIndices().getType());
  const ShapeAdaptor inputShape(getInput().getType());
  const ShapeAdaptor outputShape(getValuesOut().getType());

  // N/K/W/C extents accumulated across operands; kDynamic = unknown so far.
  int64_t n = ShapedType::kDynamic;
  int64_t k = ShapedType::kDynamic;
  int64_t w = ShapedType::kDynamic;
  int64_t c = ShapedType::kDynamic;
  if (valuesInShape.hasRank()) {
    n = valuesInShape.getDimSize(0);
    k = valuesInShape.getDimSize(1);
    c = valuesInShape.getDimSize(2);
  }
  if (indicesShape.hasRank()) {
    const int64_t indicesN = indicesShape.getDimSize(0);
    w = indicesShape.getDimSize(1);
    if (n == ShapedType::kDynamic)
      n = indicesN;
    else if (indicesN != ShapedType::kDynamic && n != indicesN)
      return emitOpError() << "requires indices dimension 0 to have size " << n
                           << ", got " << indicesN;
  }
  if (inputShape.hasRank()) {
    const int64_t inputN = inputShape.getDimSize(0);
    const int64_t inputW = inputShape.getDimSize(1);
    const int64_t inputC = inputShape.getDimSize(2);
    if (n == ShapedType::kDynamic)
      n = inputN;
    else if (inputN != ShapedType::kDynamic && n != inputN)
      return emitOpError() << "requires input dimension 0 to have size " << n
                           << ", got " << inputN;
    if (w == ShapedType::kDynamic)
      w = inputW;
    else if (inputW != ShapedType::kDynamic && w != inputW)
      return emitOpError() << "requires input dimension 1 to have size " << w
                           << ", got " << inputW;

    if (c == ShapedType::kDynamic)
      c = inputC;
    else if (inputC != ShapedType::kDynamic && c != inputC)
      return emitOpError() << "requires input dimension 2 to have size " << c
                           << ", got " << inputC;
  }
  if (outputShape.hasRank()) {
    const int64_t outputN = outputShape.getDimSize(0);
    const int64_t outputK = outputShape.getDimSize(1);
    const int64_t outputC = outputShape.getDimSize(2);
    if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
        n != outputN)
      return emitOpError() << "requires values_out dimension 0 to have size "
                           << n << ", got " << outputN;
    if (k == ShapedType::kDynamic)
      k = outputK;
    else if (outputK != ShapedType::kDynamic && k != outputK)
      return emitOpError() << "requires values_out dimension 1 to have size "
                           << k << ", got " << outputK;
    if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
        c != outputC)
      return emitOpError() << "requires values_out dimension 2 to have size "
                           << c << ", got " << outputC;
  }
  // Every scattered row must fit into the K rows of values_out.
  if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
    return emitOpError() << "requires dimensions K >= W, got K=" << k
                         << " and W=" << w;

  return success();
}
3813
3814static LogicalResult ReduceInferReturnTypes(
3815 ShapeAdaptor operandShape, Type inputType, IntegerAttr axis,
3816 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3817 int64_t axisVal = axis.getValue().getSExtValue();
3818 if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) {
3819 inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
3820 return success();
3821 }
3822
3823 SmallVector<int64_t> outputShape;
3824 operandShape.getDims(outputShape);
3825 outputShape[axisVal] = 1;
3826 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
3827 return success();
3828}
3829
// Generates OP::isCompatibleReturnTypes: two return-type lists are
// compatible when each holds exactly one type with the same element type
// and compatible shapes.
#define COMPATIBLE_RETURN_TYPES(OP)                                           \
  bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) {                 \
    if (l.size() != r.size() || l.size() != 1)                                 \
      return false;                                                            \
    if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0]))              \
      return false;                                                            \
    return succeeded(verifyCompatibleShape(l[0], r[0]));                       \
  }

// Generates OP::inferReturnTypeComponents (delegating to
// ReduceInferReturnTypes with the op's axis property) plus the matching
// isCompatibleReturnTypes.
#define REDUCE_SHAPE_INFER(OP)                                                 \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,                \
      OP::Adaptor adaptor,                                                     \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    Type inputType =                                                           \
        llvm::cast<TensorType>(adaptor.getInput().getType()).getElementType(); \
    ShapeAdaptor inputShape(adaptor.getInput().getType());                     \
    const Properties &prop = adaptor.getProperties();                          \
    return ReduceInferReturnTypes(inputShape, inputType, prop.axis,            \
                                  inferredReturnShapes);                       \
  }                                                                            \
  COMPATIBLE_RETURN_TYPES(OP)

REDUCE_SHAPE_INFER(tosa::ReduceAllOp)
REDUCE_SHAPE_INFER(tosa::ReduceAnyOp)
REDUCE_SHAPE_INFER(tosa::ReduceMaxOp)
REDUCE_SHAPE_INFER(tosa::ReduceMinOp)
REDUCE_SHAPE_INFER(tosa::ReduceProductOp)
REDUCE_SHAPE_INFER(tosa::ReduceSumOp)
#undef REDUCE_SHAPE_INFER
// ConcatOp shares the compatibility rule but has its own shape inference.
COMPATIBLE_RETURN_TYPES(tosa::ConcatOp)
#undef COMPATIBLE_RETURN_TYPES
3862
// Shared verifier for all TOSA reduce ops (which all have input, output and
// axis): the axis must be non-negative and within both ranks (rank 0 with
// axis 0 is allowed), input and output ranks must match, and the reduced
// output dimension must be 1 when static.
template <typename T>
static LogicalResult verifyReduceOp(T op) {
  // All TOSA reduce Ops have input, output and axis.
  TensorType inputType = op.getInput().getType();
  TensorType outputType = op.getOutput().getType();
  int32_t reduceAxis = op.getAxis();

  if (reduceAxis < 0) {
    op.emitOpError("reduce axis must not be negative");
    return failure();
  }
  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0 and
    // axis is also 0.
    if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) {
      op.emitOpError("expect input tensor rank (")
          << inputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank()) {
      op.emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
      return failure();
    }
    if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) {
      op.emitOpError("expect output tensor rank (")
          << outputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
    // We can only verify the reduced dimension size to be 1 if this is not
    // the special case of output rank == 0.
    if (outputRank != 0) {
      auto outputShape = outputType.getShape();
      if (!outputType.isDynamicDim(reduceAxis) &&
          outputShape[reduceAxis] != 1) {
        op.emitOpError("expect reduced dimension size to be 1, got ")
            << outputShape[reduceAxis];
        return failure();
      }
    }
  }
  return success();
}
3912
// All reduce ops share the structural checks implemented in verifyReduceOp.
LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }
3919
3920static LogicalResult NAryInferReturnTypes(
3921 const ValueShapeRange &operands,
3922 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3924 if (resolveBroadcastShape(operands, outShape).failed()) {
3925 inferredReturnShapes.push_back(ShapedTypeComponents());
3926 } else {
3927 inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
3928 }
3929 return success();
3930}
3931
3932#define NARY_SHAPE_INFER(OP) \
3933 LogicalResult OP::inferReturnTypeComponents( \
3934 MLIRContext *context, ::std::optional<Location> location, \
3935 ValueShapeRange operands, DictionaryAttr attributes, \
3936 PropertyRef properties, RegionRange regions, \
3937 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
3938 return NAryInferReturnTypes(operands, inferredReturnShapes); \
3939 }
3940
3941NARY_SHAPE_INFER(tosa::AbsOp)
3942NARY_SHAPE_INFER(tosa::AddOp)
3943NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
3944NARY_SHAPE_INFER(tosa::BitwiseAndOp)
3945NARY_SHAPE_INFER(tosa::BitwiseOrOp)
3946NARY_SHAPE_INFER(tosa::BitwiseXorOp)
3947NARY_SHAPE_INFER(tosa::BitwiseNotOp)
3948NARY_SHAPE_INFER(tosa::CastOp)
3949NARY_SHAPE_INFER(tosa::CeilOp)
3950NARY_SHAPE_INFER(tosa::ClampOp)
3951NARY_SHAPE_INFER(tosa::ClzOp)
3952NARY_SHAPE_INFER(tosa::CosOp)
3953NARY_SHAPE_INFER(tosa::ExpOp)
3954NARY_SHAPE_INFER(tosa::FloorOp)
3955NARY_SHAPE_INFER(tosa::GreaterEqualOp)
3956NARY_SHAPE_INFER(tosa::GreaterOp)
3957NARY_SHAPE_INFER(tosa::IdentityOp)
3958NARY_SHAPE_INFER(tosa::IntDivOp)
3959NARY_SHAPE_INFER(tosa::LogOp)
3960NARY_SHAPE_INFER(tosa::LogicalAndOp)
3961NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
3962NARY_SHAPE_INFER(tosa::LogicalNotOp)
3963NARY_SHAPE_INFER(tosa::LogicalOrOp)
3964NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
3965NARY_SHAPE_INFER(tosa::LogicalXorOp)
3966NARY_SHAPE_INFER(tosa::MaximumOp)
3967NARY_SHAPE_INFER(tosa::MinimumOp)
3968NARY_SHAPE_INFER(tosa::PowOp)
3969NARY_SHAPE_INFER(tosa::ReciprocalOp)
3970NARY_SHAPE_INFER(tosa::ReverseOp)
3971NARY_SHAPE_INFER(tosa::RsqrtOp)
3972NARY_SHAPE_INFER(tosa::SinOp)
3973NARY_SHAPE_INFER(tosa::SelectOp)
3974NARY_SHAPE_INFER(tosa::SubOp)
3975NARY_SHAPE_INFER(tosa::TanhOp)
3976NARY_SHAPE_INFER(tosa::ErfOp)
3977NARY_SHAPE_INFER(tosa::SigmoidOp)
3978#undef PRED_SHAPE_INFER
3979
3980LogicalResult tosa::NegateOp::inferReturnTypeComponents(
3981 MLIRContext *context, ::std::optional<Location> location,
3982 NegateOp::Adaptor adaptor,
3983 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
3984 ShapeAdaptor inputShape(adaptor.getInput1().getType());
3985 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
3986 return success();
3987}
3988
// Verifies tosa.negate: input1 and output must share element type and a
// compatible shape, each zero-point operand must match its tensor's storage
// element type, and the zero-point values themselves must be valid when
// they fold to constants.
LogicalResult tosa::NegateOp::verify() {
  // Verify same element type
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();
  if (verifySameElementTypes(*this, input1Type, outputType).failed())
    return failure();

  // Verify same shape
  const SmallVector<Type, 2> types = {input1Type, outputType};
  if (failed(verifyCompatibleShapes(types)))
    return emitOpError() << "requires the same shape for input1 and output";

  // Zero points must use the same storage element type as their tensor.
  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  // Zero-point value checks only apply when the operand is a constant.
  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}
4028
4029static LogicalResult poolingInferReturnTypes(
4030 ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
4032 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4033 llvm::SmallVector<int64_t> outputShape;
4034 outputShape.resize(4, ShapedType::kDynamic);
4035
4036 // We only know the rank if the input type is unranked.
4037 if (!inputShape.hasRank()) {
4038 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4039 return success();
4040 }
4041
4042 // Batch and number of channels are identical for pooling layer.
4043 outputShape[0] = inputShape.getDimSize(0);
4044 outputShape[3] = inputShape.getDimSize(3);
4045
4046 int64_t height = inputShape.getDimSize(1);
4047 int64_t width = inputShape.getDimSize(2);
4048
4049 if (ShapedType::isStatic(height)) {
4050 int64_t padded = height + pad[0] + pad[1] - kernel[0];
4051 outputShape[1] = padded / stride[0] + 1;
4052 }
4053
4054 if (ShapedType::isStatic(width)) {
4055 int64_t padded = width + pad[2] + pad[3] - kernel[1];
4056 outputShape[2] = padded / stride[1] + 1;
4057 }
4058
4059 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4060 return success();
4061}
4062
4063template <typename AdaptorT>
4065
4067protected:
4068 static void updateIfDynamic(int64_t &current, int64_t candidate) {
4069 if (ShapedType::isDynamic(current))
4070 current = candidate;
4071 }
4072};
4073
4074template <>
4075class ConvInferShapeAdaptor<Conv2DOp::Adaptor>
4076 : public ConvInferShapeAdaptorBase {
4077public:
4078 explicit ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
4079 : adaptor(adaptor) {}
4080
4082 SmallVectorImpl<int64_t> &inputSpatial) {
4083 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4084 if (!inputShape.hasRank())
4085 return;
4086
4087 const int64_t outputBatch = inputShape.getDimSize(0);
4088 const int64_t inputHeight = inputShape.getDimSize(1);
4089 const int64_t inputWidth = inputShape.getDimSize(2);
4090
4091 outputShape[0] = outputBatch;
4092 inputSpatial[0] = inputHeight;
4093 inputSpatial[1] = inputWidth;
4094 }
4095
4097 SmallVectorImpl<int64_t> &weightSpatial) {
4098 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4099 if (!weightShape.hasRank())
4100 return;
4101
4102 const int64_t outputChannels = weightShape.getDimSize(0);
4103 const int64_t kernelHeight = weightShape.getDimSize(1);
4104 const int64_t kernelWidth = weightShape.getDimSize(2);
4105
4106 outputShape[3] = outputChannels;
4107 weightSpatial[0] = kernelHeight;
4108 weightSpatial[1] = kernelWidth;
4109 }
4110
4111 int64_t getNumSpatialDims() const { return 2; }
4112 int64_t getOutputRank() const { return 4; }
4113
4115 SmallVector<int64_t> &strideValues,
4116 SmallVector<int64_t> &dilationValues) {
4117 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4118 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4119 dilationValues.assign(adaptor.getDilation().begin(),
4120 adaptor.getDilation().end());
4121 return success();
4122 }
4123
4124private:
4125 Conv2DOp::Adaptor adaptor;
4126};
4127
4128template <>
4129class ConvInferShapeAdaptor<Conv2DBlockScaledOp::Adaptor>
4130 : public ConvInferShapeAdaptorBase {
4131public:
4132 explicit ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
4133 : adaptor(adaptor) {}
4134
4136 SmallVectorImpl<int64_t> &inputSpatial) {
4137 const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
4138 if (inputDataShape.hasRank()) {
4139 const int64_t outputBatch = inputDataShape.getDimSize(0);
4140 const int64_t inputHeight = inputDataShape.getDimSize(1);
4141 const int64_t inputWidth = inputDataShape.getDimSize(2);
4142
4143 outputShape[0] = outputBatch;
4144 inputSpatial[0] = inputHeight;
4145 inputSpatial[1] = inputWidth;
4146 }
4147
4148 const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
4149 if (!inputScaleShape.hasRank())
4150 return;
4151
4152 const int64_t scaleBatch = inputScaleShape.getDimSize(0);
4153 const int64_t scaleHeight = inputScaleShape.getDimSize(1);
4154 const int64_t scaleWidth = inputScaleShape.getDimSize(2);
4155
4156 updateIfDynamic(outputShape[0], scaleBatch);
4157 updateIfDynamic(inputSpatial[0], scaleHeight);
4158 updateIfDynamic(inputSpatial[1], scaleWidth);
4159 }
4160
4162 SmallVectorImpl<int64_t> &weightSpatial) {
4163 const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
4164 if (weightDataShape.hasRank()) {
4165 const int64_t outputChannels = weightDataShape.getDimSize(0);
4166 const int64_t kernelHeight = weightDataShape.getDimSize(1);
4167 const int64_t kernelWidth = weightDataShape.getDimSize(2);
4168
4169 outputShape[3] = outputChannels;
4170 weightSpatial[0] = kernelHeight;
4171 weightSpatial[1] = kernelWidth;
4172 }
4173
4174 const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
4175 if (!weightScaleShape.hasRank())
4176 return;
4177
4178 const int64_t scaleOutputChannels = weightScaleShape.getDimSize(0);
4179 const int64_t scaleKernelHeight = weightScaleShape.getDimSize(1);
4180 const int64_t scaleKernelWidth = weightScaleShape.getDimSize(2);
4181
4182 updateIfDynamic(outputShape[3], scaleOutputChannels);
4183 updateIfDynamic(weightSpatial[0], scaleKernelHeight);
4184 updateIfDynamic(weightSpatial[1], scaleKernelWidth);
4185 }
4186
4187 int64_t getNumSpatialDims() const { return 2; }
4188 int64_t getOutputRank() const { return 4; }
4189
4191 SmallVector<int64_t> &strideValues,
4192 SmallVector<int64_t> &dilationValues) {
4193 if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(),
4194 padValues) ||
4195 !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4196 strideValues) ||
4197 !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
4198 dilationValues))
4199 return failure();
4200 return success();
4201 }
4202
4203private:
4204 Conv2DBlockScaledOp::Adaptor adaptor;
4205};
4206
4207template <>
4208class ConvInferShapeAdaptor<Conv3DOp::Adaptor>
4209 : public ConvInferShapeAdaptorBase {
4210public:
4211 explicit ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
4212 : adaptor(adaptor) {}
4213
4215 SmallVectorImpl<int64_t> &inputSpatial) {
4216 const ShapeAdaptor inputShape(adaptor.getInput().getType());
4217 if (!inputShape.hasRank())
4218 return;
4219
4220 const int64_t outputBatch = inputShape.getDimSize(0);
4221 const int64_t inputDepth = inputShape.getDimSize(1);
4222 const int64_t inputHeight = inputShape.getDimSize(2);
4223 const int64_t inputWidth = inputShape.getDimSize(3);
4224
4225 outputShape[0] = outputBatch;
4226 inputSpatial[0] = inputDepth;
4227 inputSpatial[1] = inputHeight;
4228 inputSpatial[2] = inputWidth;
4229 }
4230
4232 SmallVectorImpl<int64_t> &weightSpatial) {
4233 const ShapeAdaptor weightShape(adaptor.getWeight().getType());
4234 if (!weightShape.hasRank())
4235 return;
4236
4237 const int64_t outputChannels = weightShape.getDimSize(0);
4238 const int64_t kernelDepth = weightShape.getDimSize(1);
4239 const int64_t kernelHeight = weightShape.getDimSize(2);
4240 const int64_t kernelWidth = weightShape.getDimSize(3);
4241
4242 outputShape[4] = outputChannels;
4243 weightSpatial[0] = kernelDepth;
4244 weightSpatial[1] = kernelHeight;
4245 weightSpatial[2] = kernelWidth;
4246 }
4247
4248 int64_t getNumSpatialDims() const { return 3; }
4249 int64_t getOutputRank() const { return 5; }
4250
4252 SmallVector<int64_t> &strideValues,
4253 SmallVector<int64_t> &dilationValues) {
4254 padValues.assign(adaptor.getPad().begin(), adaptor.getPad().end());
4255 strideValues.assign(adaptor.getStride().begin(), adaptor.getStride().end());
4256 dilationValues.assign(adaptor.getDilation().begin(),
4257 adaptor.getDilation().end());
4258 return success();
4259 }
4260
4261private:
4262 Conv3DOp::Adaptor adaptor;
4263};
4264
4265template <typename AdaptorT>
4267 AdaptorT adaptor,
4268 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4269 ConvInferShapeAdaptor<AdaptorT> convShapeAdaptor(adaptor);
4270 llvm::SmallVector<int64_t> outputShape(convShapeAdaptor.getOutputRank(),
4271 ShapedType::kDynamic);
4272 llvm::SmallVector<int64_t> inputSpatial(convShapeAdaptor.getNumSpatialDims(),
4273 ShapedType::kDynamic);
4274 llvm::SmallVector<int64_t> weightSpatial(convShapeAdaptor.getNumSpatialDims(),
4275 ShapedType::kDynamic);
4276
4277 convShapeAdaptor.inferInputShape(outputShape, inputSpatial);
4278 convShapeAdaptor.inferWeightShape(outputShape, weightSpatial);
4279
4280 const ShapeAdaptor biasShape = adaptor.getBias().getType();
4281 if (biasShape.hasRank()) {
4282 const int64_t biasSize = biasShape.getDimSize(0);
4283 if (biasSize != 1) {
4284 const size_t outputChannelDim = convShapeAdaptor.getOutputRank() - 1;
4285 outputShape[outputChannelDim] =
4286 ShapedType::isDynamic(outputShape[outputChannelDim])
4287 ? biasSize
4288 : outputShape[outputChannelDim];
4289 }
4290 }
4291
4292 SmallVector<int64_t> padValues;
4293 SmallVector<int64_t> strideValues;
4294 SmallVector<int64_t> dilationValues;
4295 if (failed(convShapeAdaptor.getSpatialParameters(padValues, strideValues,
4296 dilationValues))) {
4297 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4298 return success();
4299 }
4300
4301 for (int64_t dim = 0; dim < convShapeAdaptor.getNumSpatialDims(); ++dim) {
4302 if (!ShapedType::isStatic(inputSpatial[dim]) ||
4303 !ShapedType::isStatic(weightSpatial[dim]))
4304 continue;
4305 const int64_t inputSize =
4306 inputSpatial[dim] + padValues[2 * dim] + padValues[2 * dim + 1];
4307 const int64_t filterSize =
4308 (weightSpatial[dim] - 1) * dilationValues[dim] + 1;
4309 const int64_t unstridedResult = inputSize - filterSize + 1;
4310 outputShape[dim + 1] = (unstridedResult - 1) / strideValues[dim] + 1;
4311 }
4312
4313 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4314 return success();
4315}
4316
4317LogicalResult Conv2DOp::inferReturnTypeComponents(
4318 MLIRContext *context, ::std::optional<Location> location,
4319 Conv2DOp::Adaptor adaptor,
4320 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4321 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4322}
4323
4324LogicalResult Conv2DOp::verify() {
4325 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4326 verifyConvOpErrorIf(*this).failed())
4327 return failure();
4328 return success();
4329}
4330
4331LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
4332 MLIRContext *context, ::std::optional<Location> location,
4333 Conv2DBlockScaledOp::Adaptor adaptor,
4334 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4335 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4336}
4337
// Verifier for CONV2D_BLOCK_SCALED. Checks element-type agreement between the
// data/scale operand pairs, cross-checks every dimension that appears in more
// than one operand, validates the block size and the pad/stride/dilation
// shape operands, and finally checks the computed output size and the bias
// shape.
LogicalResult Conv2DBlockScaledOp::verify() {
  // Data and scale operands of a pair must agree on element type; bias must
  // match the output element type.
  if (failed(verifySameElementTypes(*this, getInputData().getType(),
                                    getWeightData().getType(), "input_data",
                                    "weight_data")) ||
      failed(verifySameElementTypes(*this, getInputScale().getType(),
                                    getWeightScale().getType(), "input_scale",
                                    "weight_scale")) ||
      failed(verifySameElementTypes(*this, getBias().getType(),
                                    getOutput().getType(), "bias", "output")))
    return failure();

  // Verify input shape compatibility
  // Each named dimension starts dynamic and is refined as the operands that
  // carry it are inspected; tryUpdateDimOrFailure diagnoses a mismatch
  // between two static sources of the same dimension.
  int64_t N = ShapedType::kDynamic;
  int64_t IH = ShapedType::kDynamic;
  int64_t IW = ShapedType::kDynamic;
  int64_t IC = ShapedType::kDynamic;
  int64_t multiplesOfIC = ShapedType::kDynamic;
  int64_t OC = ShapedType::kDynamic;
  int64_t KH = ShapedType::kDynamic;
  int64_t KW = ShapedType::kDynamic;

  // input_data is laid out as [N, IH, IW, IC].
  const ShapeAdaptor inputDataShape(getInputData().getType());
  if (inputDataShape.hasRank()) {
    N = inputDataShape.getDimSize(0);
    IH = inputDataShape.getDimSize(1);
    IW = inputDataShape.getDimSize(2);
    IC = inputDataShape.getDimSize(3);
  }

  // input_scale shares N/IH/IW with input_data; its last dimension holds one
  // scale per block of IC.
  const ShapeAdaptor inputScaleShape(getInputScale().getType());
  if (inputScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
                                     "input_scale", "batch size")) ||
        failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
                                     "input_scale", "input height")) ||
        failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
                                     "input_scale", "input width")))
      return failure();
    multiplesOfIC = inputScaleShape.getDimSize(3);
  }

  // weight_data is laid out as [OC, KH, KW, IC].
  const ShapeAdaptor weightDataShape(getWeightData().getType());
  if (weightDataShape.hasRank()) {
    OC = weightDataShape.getDimSize(0);
    KH = weightDataShape.getDimSize(1);
    KW = weightDataShape.getDimSize(2);
    if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
                                     "weight_data", "input channels")))
      return failure();
  }

  // weight_scale must agree with weight_data on OC/KH/KW and with
  // input_scale on the number of IC blocks.
  const ShapeAdaptor weightScaleShape(getWeightScale().getType());
  if (weightScaleShape.hasRank()) {
    if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
                                     "weight_scale", "output channels")) ||
        failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
                                     "weight_scale", "kernel height")) ||
        failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
                                     "weight_scale", "kernel width")) ||
        failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
                                     weightScaleShape.getDimSize(3),
                                     "weight_scale", "input channel blocks")))
      return failure();
  }

  // Only BLOCK_SIZE_32 is currently accepted for this op.
  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
  if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
    return emitOpError("expect block size to be 32, got ") << blockSize;
  // Verify IC is a multiple of block size
  if (ShapedType::isStatic(IC) && IC % blockSize != 0)
    return emitOpError("expect IC to be a multiple of block size, got IC=")
        << IC << ", block_size=" << blockSize;

  // Verify multiplesOfIC is IC / block size
  // NOTE(review): the diagnostic says "dimension 2" while the dimension
  // checked above is index 3 (the last axis) — confirm intended wording.
  if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
      multiplesOfIC != IC / blockSize)
    return emitOpError(
               "expect scale operands dimension 2 to equal IC/block_size (")
           << IC << "/" << blockSize << ")"
           << ", got " << multiplesOfIC;

  // Verify pad/stride/dilation values
  // These are shape operands; the checks run only when the values are
  // compile-time constants.
  SmallVector<int64_t> padValues;
  if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
    if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
      return emitOpError("expect all padding values to be >= 0, got ")
             << padValues;
  }

  SmallVector<int64_t> strideValues;
  if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
    if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
      return emitOpError("expect all stride values to be >= 1, got ")
             << strideValues;
  }

  SmallVector<int64_t> dilationValues;
  if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
                                dilationValues)) {
    if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
      return emitOpError("expect all dilation values to be >= 1, got ")
             << dilationValues;
  }

  // Verify output shape compatibility
  // Requires all three parameter sets to be constant and the output ranked.
  const ShapeAdaptor outputShape(getOutput().getType());
  if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() &&
      outputShape.hasRank()) {
    if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
                                    padValues[0], padValues[1], strideValues[0],
                                    dilationValues[0], "height", "y", "top",
                                    "bottom")) ||
        failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
                                    padValues[2], padValues[3], strideValues[1],
                                    dilationValues[1], "width", "x", "left",
                                    "right")))
      return failure();
  }

  // Verify bias
  // Bias must either broadcast (size 1) or match the output channel count.
  const ShapeAdaptor biasShape(getBias().getType());
  if (biasShape.hasRank() && outputShape.hasRank()) {
    const int64_t biasChannels = biasShape.getDimSize(0);
    const int64_t outputChannels =
        outputShape.getDimSize(outputShape.getRank() - 1);
    if (biasChannels == ShapedType::kDynamic ||
        outputChannels == ShapedType::kDynamic)
      // Skip following checks if biasChannels or outputChannels is dynamic dim
      return success();

    if (biasChannels != outputChannels && biasChannels != 1)
      return emitOpError(
                 "bias channels expected to be equal to output channels (")
             << outputChannels << ") or 1, got " << biasChannels;
  }

  return success();
}
4476
4477LogicalResult Conv3DOp::inferReturnTypeComponents(
4478 MLIRContext *context, ::std::optional<Location> location,
4479 Conv3DOp::Adaptor adaptor,
4480 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4481 return inferConvReturnTypeComponents(adaptor, inferredReturnShapes);
4482}
4483
4484LogicalResult Conv3DOp::verify() {
4485 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4486 verifyConvOpErrorIf(*this).failed())
4487 return failure();
4488 return success();
4489}
4490
4491LogicalResult AvgPool2dOp::inferReturnTypeComponents(
4492 MLIRContext *context, ::std::optional<Location> location,
4493 AvgPool2dOp::Adaptor adaptor,
4494 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4495 ShapeAdaptor inputShape(adaptor.getInput().getType());
4496 const Properties &prop = adaptor.getProperties();
4497 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
4498 inferredReturnShapes);
4499}
4500
4501LogicalResult AvgPool2dAdaptiveOp::inferReturnTypeComponents(
4502 MLIRContext *context, ::std::optional<Location> location,
4503 AvgPool2dAdaptiveOp::Adaptor adaptor,
4504 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4505 ShapeAdaptor inputShape(adaptor.getInput().getType());
4506
4507 llvm::SmallVector<int64_t> kernelValues;
4508 llvm::SmallVector<int64_t> strideValues;
4509 llvm::SmallVector<int64_t> padValues;
4510 if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
4511 kernelValues) &&
4512 tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4513 strideValues) &&
4514 tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
4515 return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
4516 padValues, inferredReturnShapes);
4517 }
4518
4519 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4520 if (inputShape.hasRank()) {
4521 // Keep N & C as pooling only changes H & W.
4522 outputShape[0] = inputShape.getDimSize(0);
4523 outputShape[3] = inputShape.getDimSize(3);
4524 }
4525
4526 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4527 return success();
4528}
4529
4530LogicalResult MaxPool2dOp::inferReturnTypeComponents(
4531 MLIRContext *context, ::std::optional<Location> location,
4532 MaxPool2dOp::Adaptor adaptor,
4533 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4534 ShapeAdaptor inputShape(adaptor.getInput().getType());
4535 const Properties &prop = adaptor.getProperties();
4536 return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad,
4537 inferredReturnShapes);
4538}
4539
4540LogicalResult MaxPool2dAdaptiveOp::inferReturnTypeComponents(
4541 MLIRContext *context, ::std::optional<Location> location,
4542 MaxPool2dAdaptiveOp::Adaptor adaptor,
4543 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4544 ShapeAdaptor inputShape(adaptor.getInput().getType());
4545
4546 llvm::SmallVector<int64_t> kernelValues;
4547 llvm::SmallVector<int64_t> strideValues;
4548 llvm::SmallVector<int64_t> padValues;
4549 if (tosa::getConstShapeValues(adaptor.getKernel().getDefiningOp(),
4550 kernelValues) &&
4551 tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
4552 strideValues) &&
4553 tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues)) {
4554 return poolingInferReturnTypes(inputShape, kernelValues, strideValues,
4555 padValues, inferredReturnShapes);
4556 }
4557
4558 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4559 if (inputShape.hasRank()) {
4560 outputShape[0] = inputShape.getDimSize(0);
4561 outputShape[3] = inputShape.getDimSize(3);
4562 }
4563 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4564 return success();
4565}
4566
4567LogicalResult MaxPool2dOp::verify() {
4568 if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
4569 /* outType = */ getOutput().getType())))
4570 return failure();
4571
4572 if (failed(verifyPoolingOp(*this)))
4573 return failure();
4574
4575 return success();
4576}
4577
// Verifier for MAX_POOL2D_ADAPTIVE: input and output must share an element
// type, and the pooling parameters must satisfy the common pooling checks.
LogicalResult MaxPool2dAdaptiveOp::verify() {
  if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(),
                                    /* outType = */ getOutput().getType())))
    return failure();

  AdaptivePoolingConstShapeValues values;
  // NOTE(review): `values` is presumably populated from the op's constant
  // kernel/stride/pad shape operands before use — the initializing statement
  // appears to have been elided in this rendering; confirm against the
  // original source.

  if (failed(verifyPoolingOpImpl(getOperation(), values.kernel, values.stride,
                                 values.pad, getInput(), getOutput())))
    return failure();

  return success();
}
4592
4593LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
4594 MLIRContext *context, ::std::optional<Location> location,
4595 DepthwiseConv2DOp::Adaptor adaptor,
4596 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4597 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4598
4599 int64_t inputWidth = ShapedType::kDynamic;
4600 int64_t inputHeight = ShapedType::kDynamic;
4601 int64_t inputChannels = ShapedType::kDynamic;
4602
4603 int64_t weightWidth = ShapedType::kDynamic;
4604 int64_t weightHeight = ShapedType::kDynamic;
4605 int64_t depthChannels = ShapedType::kDynamic;
4606
4607 // Input shape describes input width/height and batch.
4608 ShapeAdaptor inputShape(adaptor.getInput().getType());
4609 if (inputShape.hasRank()) {
4610 outputShape[0] = inputShape.getDimSize(0);
4611 inputHeight = inputShape.getDimSize(1);
4612 inputWidth = inputShape.getDimSize(2);
4613 inputChannels = inputShape.getDimSize(3);
4614 }
4615
4616 // Weight shapes describes the filter width/height and the output channels.
4617 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4618 if (weightShape.hasRank()) {
4619 weightHeight = weightShape.getDimSize(0);
4620 weightWidth = weightShape.getDimSize(1);
4621 inputChannels = ShapedType::isDynamic(inputChannels)
4622 ? weightShape.getDimSize(2)
4623 : inputChannels;
4624 depthChannels = weightShape.getDimSize(3);
4625 }
4626
4627 // If both inputChannels and depthChannels are available we can determine
4628 // the output channels.
4629 if (ShapedType::isStatic(inputChannels) &&
4630 ShapedType::isStatic(depthChannels)) {
4631 outputShape[3] = inputChannels * depthChannels;
4632 }
4633
4634 // Bias shape can describe the output channels.
4635 ShapeAdaptor biasShape(adaptor.getBias().getType());
4636 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4637 int64_t bc = biasShape.getDimSize(0);
4638 if (bc != ShapedType::kDynamic && bc != 1)
4639 outputShape[3] = bc;
4640 }
4641
4642 llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
4643 llvm::ArrayRef<int64_t> padding = adaptor.getPad();
4644 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4645
4646 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4647 int64_t inputSize = inputHeight + padding[0] + padding[1];
4648 int64_t filterSize = (weightHeight - 1) * dilation[0] + 1;
4649 int64_t unstridedResult = inputSize - filterSize + 1;
4650 outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
4651 }
4652
4653 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4654 int64_t inputSize = inputWidth + padding[2] + padding[3];
4655 int64_t filterSize = (weightWidth - 1) * dilation[1] + 1;
4656 int64_t unstridedResult = inputSize - filterSize + 1;
4657 outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
4658 }
4659
4660 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4661 return success();
4662}
4663
4664LogicalResult DepthwiseConv2DOp::verify() {
4665 if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() ||
4666 verifyConvOpErrorIf(*this).failed())
4667 return failure();
4668 return success();
4669}
4670
4671LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
4672 MLIRContext *context, ::std::optional<Location> location,
4673 TransposeConv2DOp::Adaptor adaptor,
4674 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4675 llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamic);
4676
4677 int64_t inputWidth = ShapedType::kDynamic;
4678 int64_t inputHeight = ShapedType::kDynamic;
4679 int64_t weightWidth = ShapedType::kDynamic;
4680 int64_t weightHeight = ShapedType::kDynamic;
4681
4682 // Input shape describes input width/height and batch.
4683 ShapeAdaptor inputShape(adaptor.getInput().getType());
4684 if (inputShape.hasRank()) {
4685 outputShape[0] = ShapedType::isDynamic(outputShape[0])
4686 ? inputShape.getDimSize(0)
4687 : outputShape[0];
4688 inputHeight = inputShape.getDimSize(1);
4689 inputWidth = inputShape.getDimSize(2);
4690 }
4691
4692 // Weight shapes describes the filter width/height and the output channels.
4693 ShapeAdaptor weightShape(adaptor.getWeight().getType());
4694 if (weightShape.hasRank()) {
4695 outputShape[3] = ShapedType::isDynamic(outputShape[3])
4696 ? weightShape.getDimSize(0)
4697 : outputShape[3];
4698 weightHeight = weightShape.getDimSize(1);
4699 weightWidth = weightShape.getDimSize(2);
4700 }
4701
4702 // Bias shape can describe the output channels.
4703 ShapeAdaptor biasShape(adaptor.getBias().getType());
4704 if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
4705 int64_t bc = biasShape.getDimSize(0);
4706 if (bc != ShapedType::kDynamic && bc != 1)
4707 outputShape[3] = bc;
4708 }
4709
4710 llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
4711 llvm::ArrayRef<int64_t> stride = adaptor.getStride();
4712
4713 if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) {
4714 int64_t calculateSize =
4715 (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
4716 outputShape[1] =
4717 ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
4718 }
4719
4720 if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
4721 int64_t calculateSize =
4722 (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
4723 outputShape[2] =
4724 ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
4725 }
4726
4727 inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
4728 return success();
4729}
4730
// Verifier for TRANSPOSE_CONV2D. On top of the shared convolution checks it
// validates strides, the out_pad values against the kernel extents, the
// transpose-conv output-size relation OH/OW == (IH/IW - 1) * stride +
// out_pad_before + out_pad_after + KH/KW, and the bias channel count.
LogicalResult TransposeConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();

  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strides << "]";

  // Each out_pad value must satisfy pad > -kernel_dim so the padded extent
  // stays positive; shared helper for the four pad/kernel combinations.
  const auto checkPadAgainstKernelDim =
      [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
             llvm::StringRef kernelDimName) -> LogicalResult {
    if (padValue <= -kernelDimSize)
      return emitOpError("expected ")
             << padName << " > -" << kernelDimName << ", but got: " << padName
             << "=" << padValue << " and " << kernelDimName << "="
             << kernelDimSize;
    return success();
  };

  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());

  // Pad-vs-kernel checks run only for statically known kernel extents.
  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (ShapedType::isStatic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (ShapedType::isStatic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  // Rest of the checks depend on the output type being a RankedTensorType
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType = llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    // Height: OH must equal (IH - 1) * stride_y + out_pad_top +
    // out_pad_bottom + KH when both IH and OH are static.
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (ShapedType::isStatic(inputHeight) &&
        ShapedType::isStatic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    // Width: same relation with stride_x, out_pad_left/right and KW.
    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());

  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);

  // Skip further checks if bias is dynamic
  if (biasChannels == ShapedType::kDynamic)
    return success();

  // Bias must either broadcast (size 1) or match the output channel count.
  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}
4847
4848LogicalResult RescaleOp::verify() {
4849 auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
4850 if (!inputType) {
4851 emitOpError("expect shaped tensor for input, got ") << getInput().getType();
4852 return failure();
4853 }
4854
4855 auto inputElementType =
4856 getStorageElementTypeOrSelf(inputType.getElementType());
4857 if (!mlir::isa<IntegerType>(inputElementType)) {
4858 emitOpError("expect input to have integer element type, got ")
4859 << inputElementType;
4860 return failure();
4861 }
4862
4863 auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
4864 if (!outputType) {
4865 emitOpError("expect shaped tensor for output, got ")
4866 << getOutput().getType();
4867 return failure();
4868 }
4869
4870 auto outputElementType =
4871 getStorageElementTypeOrSelf(outputType.getElementType());
4872 if (!mlir::isa<IntegerType>(outputElementType)) {
4873 emitOpError("expect output to have integer element type, got ")
4874 << outputElementType;
4875 return failure();
4876 }
4877
4878 if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
4879 .failed())
4880 return failure();
4881
4882 if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
4883 .failed())
4884 return failure();
4885
4886 FailureOr<int64_t> maybeIZp = getInputZeroPoint();
4887 if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
4888 return failure();
4889
4890 FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
4891 if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
4892 return failure();
4893
4894 auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
4895 if (!multiplierType) {
4896 emitOpError("expect shaped tensor for multiplier, got ")
4897 << getMultiplier().getType();
4898 return failure();
4899 }
4900
4901 auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
4902 if (!shiftType) {
4903 emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
4904 return failure();
4905 }
4906
4907 // multiplier element type must be i32 for scale32 = true
4908 if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
4909 emitOpError("expect i32 element type for multiplier for scale32=true, got ")
4910 << multiplierType.getElementType();
4911 return failure();
4912 }
4913
4914 // multiplier element type must be i16 for scale32 = false
4915 if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
4917 "expect i16 element type for multiplier for scale32=false, got ")
4918 << multiplierType.getElementType();
4919 return failure();
4920 }
4921
4922 if (!inputType.hasRank())
4923 return success();
4924
4925 // multiplier/shift must have shape = {numChannels},
4926 // where numChannel is 1 if per_channel = false
4927 // otherwise numChannel is dimension in input shape's last axis
4928 int64_t numChannels = 1;
4929 if (getPerChannel()) {
4930 if (inputType.getRank() < 1) {
4931 emitOpError("requires input to be at least rank 1 when per_channel is "
4932 "true, but got rank ")
4933 << inputType.getRank();
4934 return failure();
4935 }
4936 numChannels = inputType.getDimSize(inputType.getRank() - 1);
4937 }
4938
4939 if (!multiplierType.hasRank())
4940 return success();
4941
4942 ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
4943 // multiplier input has rank 1 by dialect definition
4944 if (multiplierShape[0] != ShapedType::kDynamic &&
4945 multiplierShape[0] != numChannels) {
4946 emitOpError("expect shape of { ")
4947 << numChannels << " } for multiplier input, got { "
4948 << multiplierShape[0] << " }";
4949 return failure();
4950 }
4951
4952 if (!shiftType.hasRank())
4953 return success();
4954
4955 ArrayRef<int64_t> shiftShape = shiftType.getShape();
4956 // shift input has rank 1 by dialect definition
4957 if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
4958 emitOpError("expect shape of { ")
4959 << numChannels << " } for shift input, got { " << shiftShape[0] << " }";
4960 return failure();
4961 }
4962
4963 return success();
4964}
4965
4966LogicalResult RescaleOp::inferReturnTypeComponents(
4967 MLIRContext *context, ::std::optional<Location> location,
4968 RescaleOp::Adaptor adaptor,
4969 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4970 ShapeAdaptor inputShape(adaptor.getInput().getType());
4971 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4972 return success();
4973}
4974
4975LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
4976 MLIRContext *context, ::std::optional<Location> location,
4977 CastFromBlockScaledOp::Adaptor adaptor,
4978 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
4979 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
4980 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
4981 return success();
4982}
4983
4984LogicalResult CastFromBlockScaledOp::verify() {
4985 const Type inputDataType = getInputData().getType();
4986 const Type outputDataType = getResult().getType();
4987 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
4988 return emitOpError() << "require compatible shapes for input_data ("
4989 << inputDataType << ") and " << "output_data ("
4990 << outputDataType << ")";
4991
4992 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
4993
4994 if (inputDataShape.hasRank()) {
4995 const unsigned int blockSize =
4996 BlockSizeAttr::getBlockSizeValue(getBlockSize());
4997 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
4998 return emitOpError("expect block size to be 32, got ") << blockSize;
4999 const int64_t inputDataLastDim =
5000 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
5001 if (inputDataLastDim % blockSize != 0)
5002 return emitOpError() << "expect last dimension of input_data ("
5003 << inputDataLastDim
5004 << ") to be divisible by block_size (" << blockSize
5005 << ")";
5006
5007 const Type inputScaleType = getInputScale().getType();
5008 const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
5009
5010 if (inputScaleShape.hasRank()) {
5011 SmallVector<int64_t> inputDataDims, inputScaleDims;
5012 inputDataShape.getDims(inputDataDims);
5013 inputScaleShape.getDims(inputScaleDims);
5014
5015 if (inputDataDims.size() != inputScaleDims.size() ||
5017 ArrayRef<int64_t>(inputDataDims).drop_back(1),
5018 ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
5019 return emitOpError()
5020 << "require compatible shapes for input_data (" << inputDataType
5021 << ") and " << "input_scale (" << inputScaleType
5022 << ") except for the last dimension";
5023
5024 const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
5025 inputScaleDims.back()};
5026 if (ShapedType::isStatic(inputDataLastDim) &&
5027 failed(verifyCompatibleDims(dimsToCheck)))
5028 return emitOpError()
5029 << "expect last dimension of input_scale ("
5030 << inputScaleDims.back()
5031 << ") to be equal to last dimension of input_data / block_size ("
5032 << inputDataDims.back() / blockSize << ")";
5033 }
5034 }
5035
5036 return success();
5037}
5038
5039LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
5040 MLIRContext *context, ::std::optional<Location> location,
5041 CastToBlockScaledOp::Adaptor adaptor,
5042 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5043 const ShapeAdaptor inputShape(adaptor.getInputData().getType());
5044 inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
5045 if (!inputShape.hasRank())
5046 return success();
5047
5048 // Calculate output_scale shape if ranked input provided
5049 SmallVector<int64_t> outputScaleShape;
5050 inputShape.getDims(outputScaleShape);
5051 const int64_t lastDimLoc = inputShape.getRank() - 1;
5052 const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
5053 if (ShapedType::isStatic(lastDimSize)) {
5054 const unsigned int blockSize =
5055 BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
5056 outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
5057 }
5058 inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
5059 return success();
5060}
5061
5062LogicalResult CastToBlockScaledOp::verify() {
5063 const Type inputDataType = getInputData().getType();
5064 const Type outputDataType = getResult(0).getType();
5065 if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
5066 return emitOpError() << "require compatible shapes for input_data ("
5067 << inputDataType << ") and " << "output_data ("
5068 << outputDataType << ")";
5069
5070 const unsigned int blockSize =
5071 BlockSizeAttr::getBlockSizeValue(getBlockSize());
5072 if (blockSize != BlockSizeAttr::getBlockSizeValue(BlockSize::BLOCK_SIZE_32))
5073 return emitOpError("expect block size to be 32, got ") << blockSize;
5074 const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
5075 if (inputDataShape.hasRank()) {
5076 const int64_t inputDataLastDim =
5077 inputDataShape.getDimSize(inputDataShape.getRank() - 1);
5078 if (ShapedType::isStatic(inputDataLastDim) &&
5079 inputDataLastDim % blockSize != 0)
5080 return emitOpError() << "expect last dimension of input_data ("
5081 << inputDataLastDim
5082 << ") to be divisible by block_size (" << blockSize
5083 << ")";
5084 }
5085
5086 const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
5087 const Type outputScaleType = getResult(1).getType();
5088 const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
5089 if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
5090 SmallVector<int64_t> outputDataDims, outputScaleDims;
5091 outputDataShape.getDims(outputDataDims);
5092 outputScaleShape.getDims(outputScaleDims);
5093
5094 if (outputDataDims.size() != outputScaleDims.size() ||
5096 ArrayRef<int64_t>(outputDataDims).drop_back(1),
5097 ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
5098 return emitOpError() << "require compatible shapes for output_data ("
5099 << outputDataType << ") and " << "output_scale ("
5100 << outputScaleType
5101 << ") except for the last dimension";
5102
5103 const int64_t outputDataLastDim = outputDataDims.back();
5104 const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
5105 outputScaleDims.back()};
5106 if (ShapedType::isStatic(outputDataLastDim) &&
5107 failed(verifyCompatibleDims(dimsToCheck)))
5108 return emitOpError()
5109 << "expect last dimension of output_scale ("
5110 << outputScaleDims.back()
5111 << ") to be equal to last dimension of output_data / block_size ("
5112 << outputDataDims.back() / blockSize << ")";
5113 }
5114
5115 return success();
5116}
5117
5118LogicalResult IfOp::inferReturnTypeComponents(
5119 MLIRContext *context, ::std::optional<Location> location,
5120 IfOp::Adaptor adaptor,
5121 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5122 llvm::SmallVector<tosa::YieldOp> yieldOps;
5123 for (Region *region : adaptor.getRegions()) {
5124 for (auto &block : *region)
5125 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5126 yieldOps.push_back(returnOp);
5127 }
5128
5129 if (yieldOps.empty())
5130 return failure();
5131
5132 // Get the initial type information for the yield op.
5133 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5134 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5135 for (auto operand : yieldOps.front().getOperands()) {
5136 resultKnowledge.push_back(
5137 ValueKnowledge::getKnowledgeFromType(operand.getType()));
5138 }
5139
5140 for (auto yieldOp : yieldOps) {
5141 if (resultKnowledge.size() != yieldOp.getNumOperands())
5142 return failure();
5143
5144 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5145 int32_t index = it.index();
5146 auto meet = ValueKnowledge::meet(
5147 resultKnowledge[index],
5148 ValueKnowledge::getKnowledgeFromType(it.value().getType()));
5149 if (!meet)
5150 continue;
5151 resultKnowledge[index] = meet;
5152 }
5153 }
5154
5155 for (const ValueKnowledge &result : resultKnowledge) {
5156 inferredReturnShapes.push_back(result.getShapedTypeComponents());
5157 }
5158
5159 return success();
5160}
5161
5162LogicalResult WhileOp::inferReturnTypeComponents(
5163 MLIRContext *context, ::std::optional<Location> location,
5164 WhileOp::Adaptor adaptor,
5165 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
5166 llvm::SmallVector<tosa::YieldOp> yieldOps;
5167 for (auto &block : adaptor.getBodyGraph())
5168 if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
5169 yieldOps.push_back(returnOp);
5170
5171 // TOSA's while must have a tosa.yield as its terminator. If not found this
5172 // tosa.while is invalid.
5173 if (yieldOps.empty())
5174 return failure();
5175
5176 // Get the initial type information from the operand types.
5177 llvm::SmallVector<ValueKnowledge> resultKnowledge;
5178 resultKnowledge.reserve(yieldOps.front().getNumOperands());
5179 for (auto operand : yieldOps.front().getOperands()) {
5180 resultKnowledge.push_back(
5181 ValueKnowledge::getKnowledgeFromType(operand.getType()));
5182 }
5183
5184 for (auto yieldOp : yieldOps) {
5185 if (resultKnowledge.size() != yieldOp.getNumOperands())
5186 return failure();
5187
5188 for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
5189 int32_t index = it.index();
5190 if (auto meet = ValueKnowledge::meet(
5191 resultKnowledge[index],
5192 ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
5193 resultKnowledge[index] = meet;
5194 }
5195 }
5196 }
5197
5198 for (const ValueKnowledge &result : resultKnowledge) {
5199 inferredReturnShapes.push_back(result.getShapedTypeComponents());
5200 }
5201
5202 return success();
5203}
5204
5205std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
5206 if (auto vt = llvm::dyn_cast<VectorType>(getType()))
5207 return llvm::to_vector<4>(vt.getShape());
5208 return std::nullopt;
5209}
5210
5212 Block::BlockArgListType blocksArgs,
5213 ValueRange initializers,
5214 StringRef prefix = "") {
5215 assert(blocksArgs.size() == initializers.size() &&
5216 "expected same length of arguments and initializers");
5217 if (initializers.empty())
5218 return;
5219
5220 parser << prefix << '(';
5221 llvm::interleaveComma(
5222 llvm::zip(blocksArgs, initializers), parser,
5223 [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
5224 parser << ")";
5225}
5226
5227// parse and print of IfOp refer to the implementation of SCF dialect.
5228ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
5229 // Create the regions for 'then'.
5230 result.regions.reserve(2);
5231 Region *thenRegion = result.addRegion();
5232 Region *elseRegion = result.addRegion();
5233
5234 OpAsmParser::UnresolvedOperand cond;
5235
5236 if (parser.parseOperand(cond))
5237 return failure();
5238
5239 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5240 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5241
5242 // Parse the optional block arguments
5243 OptionalParseResult listResult =
5244 parser.parseOptionalAssignmentList(regionArgs, operands);
5245 if (listResult.has_value() && failed(listResult.value()))
5246 return failure();
5247
5248 // Parse a colon.
5249 if (failed(parser.parseColon()))
5250 return parser.emitError(parser.getCurrentLocation(),
5251 "expected type for condition operand");
5252
5253 // Parse the type of the condition operand
5254 Type condType;
5255 if (failed(parser.parseType(condType)))
5256 return parser.emitError(parser.getCurrentLocation(),
5257 "expected type for condition operand");
5258
5259 // Resolve operand with provided type
5260 if (failed(parser.resolveOperand(cond, condType, result.operands)))
5261 return failure();
5262
5263 // Parse optional block arg types
5264 if (listResult.has_value()) {
5265 FunctionType functionType;
5266
5267 if (failed(parser.parseType(functionType)))
5268 return parser.emitError(parser.getCurrentLocation())
5269 << "expected list of types for block arguments "
5270 << "followed by arrow type and list of return types";
5271
5272 result.addTypes(functionType.getResults());
5273
5274 if (functionType.getNumInputs() != operands.size()) {
5275 return parser.emitError(parser.getCurrentLocation())
5276 << "expected as many input types as operands " << "(expected "
5277 << operands.size() << " got " << functionType.getNumInputs()
5278 << ")";
5279 }
5280
5281 // Resolve input operands.
5282 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
5283 parser.getCurrentLocation(),
5284 result.operands)))
5285 return failure();
5286 } else {
5287 // Parse optional results type list.
5288 if (parser.parseOptionalArrowTypeList(result.types))
5289 return failure();
5290 }
5291
5292 // Parse the 'then' region.
5293 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
5294 return failure();
5295
5296 // If we find an 'else' keyword then parse the 'else' region.
5297 if (!parser.parseOptionalKeyword("else")) {
5298 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
5299 return failure();
5300 }
5301
5302 // Parse the optional attribute list.
5303 if (parser.parseOptionalAttrDict(result.attributes))
5304 return failure();
5305 return success();
5306}
5307
5308void IfOp::print(OpAsmPrinter &p) {
5309 p << " " << getCondition();
5310
5311 printInitializationList(p, getThenGraph().front().getArguments(),
5312 getInputList(), " ");
5313 p << " : ";
5314 p << getCondition().getType();
5315
5316 if (!getInputList().empty()) {
5317 p << " (";
5318 llvm::interleaveComma(getInputList().getTypes(), p);
5319 p << ")";
5320 }
5321 p.printArrowTypeList(getResultTypes());
5322 p << " ";
5323
5324 p.printRegion(getThenGraph());
5325
5326 // Print the 'else' regions if it exists and has a block.
5327 auto &elseRegion = getElseGraph();
5328 if (!elseRegion.empty()) {
5329 p << " else ";
5330 p.printRegion(elseRegion);
5331 }
5332
5333 p.printOptionalAttrDict((*this)->getAttrs());
5334}
5335
5336LogicalResult IfOp::verify() {
5337 if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
5338 "'then_graph' arguments", getInputList(),
5339 "'input_list'")
5340 .failed())
5341 return failure();
5342
5343 if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
5344 "'else_graph' arguments", getInputList(),
5345 "'input_list'")
5346 .failed())
5347 return failure();
5348
5349 // MLIR will verify the absence of the terminator for us if otherwise.
5350 if (getThenGraph().front().mightHaveTerminator()) {
5351 auto thenYield =
5352 dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
5353 if (thenYield && errorIfTypeOrShapeMismatch(
5354 *this, thenYield.getInputs(), "'then_graph' results",
5355 getOutputList(), "'output_list'")
5356 .failed())
5357 return failure();
5358 }
5359
5360 // MLIR will verify the absence of the terminator for us if otherwise.
5361 if (getElseGraph().front().mightHaveTerminator()) {
5362 auto elseYield =
5363 dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
5364 if (elseYield && errorIfTypeOrShapeMismatch(
5365 *this, elseYield.getInputs(), "'else_graph' results",
5366 getOutputList(), "'output_list'")
5367 .failed())
5368 return failure();
5369 }
5370
5371 auto condType = getCondition().getType();
5372 if (errorIfShapeNotSizeOne(*this, condType).failed())
5373 return emitOpError() << "'condition' must be a size 1 tensor, got "
5374 << condType;
5375
5376 return success();
5377}
5378
5379LogicalResult WhileOp::verify() {
5380 if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
5381 getOutputList(), "'output_list'")
5382 .failed())
5383 return failure();
5384
5385 if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
5386 "'cond_graph' arguments", getInputList(),
5387 "'input_list'")
5388 .failed())
5389 return failure();
5390
5391 if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
5392 "'body_graph' arguments", getInputList(),
5393 "'input_list'")
5394 .failed())
5395 return failure();
5396
5397 if (getBodyGraph().front().mightHaveTerminator()) {
5398 auto bodyYield =
5399 dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
5400 if (bodyYield && errorIfTypeOrShapeMismatch(*this, bodyYield.getInputs(),
5401 "'body_graph' results",
5402 getInputList(), "'input_list'")
5403 .failed())
5404 return failure();
5405 }
5406
5407 // Condition block output must be a single element tensor with a single bool
5408 // value.
5409 if (!getCondGraph().front().mightHaveTerminator())
5410 return success();
5411
5412 auto condYield =
5413 dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
5414 if (!condYield)
5415 return success();
5416
5417 if (condYield.getInputs().size() != 1)
5418 return emitOpError() << "require 'cond_graph' only have one result";
5419
5420 auto condOutType = condYield.getInputs()[0].getType();
5421 if (errorIfShapeNotSizeOne(*this, condOutType).failed())
5422 return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
5423 << condOutType;
5424
5425 if (!getElementTypeOrSelf(condOutType).isInteger(1))
5426 return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
5427 << condOutType;
5428
5429 return success();
5430}
5431
5432LogicalResult ReverseOp::verify() {
5433 if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
5434 /* outType = */ getOutput().getType())
5435 .failed())
5436 return failure();
5437 TensorType inputType = getInput1().getType();
5438 TensorType outputType = getOutput().getType();
5439 int32_t reverseAxis = getAxis();
5440
5441 if (reverseAxis < 0)
5442 return emitOpError("expected non-negative reverse axis");
5443 if (inputType.hasRank()) {
5444 int64_t inputRank = inputType.getRank();
5445 // We allow for a special case where the input/output shape has rank 0 and
5446 // axis is also 0.
5447 if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
5448 return emitOpError("expect input tensor rank (")
5449 << inputRank << ") to be larger than reverse axis (" << reverseAxis
5450 << ")";
5451 }
5452 if (outputType.hasRank()) {
5453 int64_t outputRank = outputType.getRank();
5454 if (inputType.hasRank() && outputRank != inputType.getRank())
5455 return emitOpError(
5456 "expect output tensor rank to be equal to input tensor rank");
5457 if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
5458 return emitOpError("expect output tensor rank (")
5459 << outputRank << ") to be larger than reverse axis ("
5460 << reverseAxis << ")";
5461 }
5462 return success();
5463}
5464
5465LogicalResult tosa::SelectOp::verify() {
5466 // verify input2 and input3 have same element type as output
5467 if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
5468 /* outType = */ getOutput().getType())
5469 .failed() ||
5470 verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
5471 /* outType = */ getOutput().getType())
5472 .failed()) {
5473 return failure();
5474 }
5475 // verify input1 has element type of bool
5476 auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
5477 if (!predicateType) {
5478 return emitOpError("expect shaped tensor for input1, got ")
5479 << getInput1().getType();
5480 }
5481 auto predicateElementType = predicateType.getElementType();
5482 if (!predicateElementType.isInteger(1)) {
5483 return emitOpError("expect element type of bool for input1, got ")
5484 << predicateElementType;
5485 }
5486
5487 return success();
5488}
5489
5490LogicalResult tosa::VariableReadOp::verify() {
5491 if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
5492 .failed())
5493 return failure();
5494
5495 return success();
5496}
5497
5498LogicalResult tosa::VariableWriteOp::verify() {
5499 if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
5500 .failed())
5501 return failure();
5502
5503 return success();
5504}
5505
5506// parse and print of WhileOp refer to the implementation of SCF dialect.
5507ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
5508 SmallVector<OpAsmParser::Argument, 4> regionArgs;
5509 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
5510 Region *cond = result.addRegion();
5511 Region *body = result.addRegion();
5512
5513 OptionalParseResult listResult =
5514 parser.parseOptionalAssignmentList(regionArgs, operands);
5515 if (listResult.has_value() && failed(listResult.value()))
5516 return failure();
5517
5518 FunctionType functionType;
5519 SMLoc typeLoc = parser.getCurrentLocation();
5520 if (failed(parser.parseColonType(functionType)))
5521 return failure();
5522
5523 result.addTypes(functionType.getResults());
5524
5525 if (functionType.getNumInputs() != operands.size()) {
5526 return parser.emitError(typeLoc)
5527 << "expected as many input types as operands " << "(expected "
5528 << operands.size() << " got " << functionType.getNumInputs() << ")";
5529 }
5530
5531 // Resolve input operands.
5532 if (failed(parser.resolveOperands(operands, functionType.getInputs(),
5533 parser.getCurrentLocation(),
5534 result.operands)))
5535 return failure();
5536
5537 // Propagate the types into the region arguments.
5538 for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
5539 regionArgs[i].type = functionType.getInput(i);
5540
5541 return failure(parser.parseRegion(*cond, regionArgs) ||
5542 parser.parseKeyword("do") || parser.parseRegion(*body) ||
5543 parser.parseOptionalAttrDictWithKeyword(result.attributes));
5544}
5545
5546void WhileOp::print(OpAsmPrinter &parser) {
5547 printInitializationList(parser, getCondGraph().front().getArguments(),
5548 getInputList(), " ");
5549 parser << " : ";
5550 parser.printFunctionalType(getInputList().getTypes(),
5551 getResults().getTypes());
5552 parser << ' ';
5553 parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
5554 parser << " do ";
5555 parser.printRegion(getBodyGraph());
5556 parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
5557}
5558
5559// Create a rank-1 const tensor for zero point of the source tensor.
5560std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
5561 Location loc,
5562 Type srcElemType,
5563 int64_t zp) {
5564 srcElemType = getStorageElementTypeOrSelf(srcElemType);
5565 auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
5566 if (llvm::isa<FloatType>(srcElemType)) {
5567 auto zpAttr = DenseElementsAttr::get(
5568 zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
5569 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5570 }
5571 if (llvm::isa<IntegerType>(srcElemType)) {
5572 auto zpAttr =
5573 DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
5574 return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
5575 }
5576 llvm::errs() << "zero point is not allowed for unsupported data types\n";
5577 return std::nullopt;
5578}
5579
5580//===----------------------------------------------------------------------===//
5581// TOSA Shape and Shape Operators Helper functions.
5582//===----------------------------------------------------------------------===//
5583
5585 return mlir::isa<tosa::shapeType>(t);
5586}
5587
5588LogicalResult
5589mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
5590 int rank) {
5591 if (rank < 0)
5592 return emitError() << "invalid rank (must be >= 0): " << rank;
5593 return success();
5594}
5595
5597 for (auto v : op->getOperands()) {
5598 if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
5599 Operation *definingOp = v.getDefiningOp();
5600 if (!definingOp || !definingOp->hasTrait<TosaShapeOperator>()) {
5601 return op->emitOpError("shape operand is not compile time resolvable");
5602 }
5603 }
5604 }
5605 return success();
5606}
5607
5608LogicalResult
5610 if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
5611 return failure();
5612
5613 // delegate function that returns rank of shape type
5614 auto getRank = [](const Type type) {
5615 return mlir::cast<mlir::tosa::shapeType>(type).getRank();
5616 };
5617 auto operandTypes = op->getOperandTypes();
5618 auto resultTypes = op->getResultTypes();
5619
5620 auto rank = getRank(*op->getOperandTypes().begin());
5621 for (auto type : operandTypes) {
5622 if (getRank(type) != rank) {
5623 return op->emitOpError("operands don't have matching ranks");
5624 }
5625 }
5626 for (auto type : resultTypes) {
5627 if (getRank(type) != rank) {
5628 return op->emitOpError("result shape has different rank than operands");
5629 }
5630 }
5631 return success();
5632}
5633
5634//===----------------------------------------------------------------------===//
5635// TOSA Shape Operators verify functions.
5636//===----------------------------------------------------------------------===//
5637
5638LogicalResult tosa::ConstShapeOp::verify() {
5639 // check one dimensional rank
5640 auto valuesRank = getValues().getType().getRank();
5641 if (valuesRank != 1)
5642 return emitOpError("expect elements in attribute values with rank 1");
5643 // check that number of elements in values attr equal to rank of result shape
5644 auto count = getValues().getNumElements();
5645 auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
5646 if (count != rank && (count != 1 || rank != 0)) {
5647 return emitOpError("expect number of elements in attribute values (")
5648 << count << ") to be equal to the rank (" << rank
5649 << ") for the result shape type";
5650 }
5651 return success();
5652}
5653
5654LogicalResult tosa::DimOp::verify() {
5655 const tosa::shapeType outShapeType =
5656 cast<tosa::shapeType>(getResult().getType());
5657 if (outShapeType.getRank() != 1)
5658 return emitOpError("expect output shape type to contain one element, got ")
5659 << outShapeType;
5660
5661 const ShapeAdaptor inputType(getInput1().getType());
5662 if (inputType.hasRank()) {
5663 const int64_t inputRank = inputType.getRank();
5664 const int64_t axis = getAxisAttr().getInt();
5665 if (axis < 0 || axis >= inputRank)
5666 return emitOpError("expect axis to be in the range [0, ")
5667 << inputRank << "), got " << axis;
5668 }
5669 return success();
5670}
5671
5672LogicalResult tosa::ConcatShapeOp::verify() {
5673 const tosa::shapeType outShapeType =
5674 cast<tosa::shapeType>(getResult().getType());
5675 const int64_t outputRank = outShapeType.getRank();
5676 const Operation::operand_range inputList = getInput();
5677
5678 if (inputList.size() == 0)
5679 return emitOpError("requires at least one input shape");
5680
5681 if (llvm::any_of(inputList, [](Value v) {
5682 return cast<tosa::shapeType>(v.getType()).getRank() == 0;
5683 }))
5684 return emitOpError("requires all inputs shapes have a rank greater than 0");
5685
5686 const int64_t inputsRank =
5687 llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) {
5688 const tosa::shapeType inShapeType =
5689 cast<tosa::shapeType>(input.getType());
5690 return acc + inShapeType.getRank();
5691 });
5692 if (outputRank != inputsRank)
5693 return emitOpError("requires output shape rank to be equal to the sum of "
5694 "the input shape ranks (")
5695 << inputsRank << "), got " << outputRank;
5696
5697 return success();
5698}
5699
5700LogicalResult tosa::SliceShapeOp::verify() {
5701 std::optional<int32_t> start;
5702 DenseIntElementsAttr startAttr;
5703 if (matchPattern(getStart(), m_Constant(&startAttr)))
5704 start = startAttr.getValues<int32_t>()[0];
5705 if (start && start.value() < 0)
5706 return emitOpError("expected non-negative start index, got ")
5707 << start.value();
5708
5709 std::optional<int32_t> size;
5710 DenseIntElementsAttr sizeAttr;
5711 if (matchPattern(getSize(), m_Constant(&sizeAttr)))
5712 size = sizeAttr.getValues<int32_t>()[0];
5713 if (size && size.value() <= 0)
5714 return emitOpError("expected positive size, got ") << size.value();
5715
5716 if (!size)
5717 return success();
5718
5719 const tosa::shapeType outShapeType =
5720 cast<tosa::shapeType>(getResult().getType());
5721 const int64_t outputRank = outShapeType.getRank();
5722 if (outputRank != size)
5723 return emitOpError(
5724 "expected output type size to be equal to size attribute, got ")
5725 << outputRank << " vs " << size.value();
5726
5727 if (!start)
5728 return success();
5729
5730 const tosa::shapeType inShapeType =
5731 cast<tosa::shapeType>(getInput().getType());
5732 const int64_t inputRank = inShapeType.getRank();
5733 const int64_t sliceSize = start.value() + size.value();
5734 if (sliceSize > inputRank)
5735 return emitOpError("expected start + size to be less than or equal to "
5736 "input shape rank (")
5737 << inputRank << "), got " << sliceSize;
5738
5739 return success();
5740}
5741
5742//===----------------------------------------------------------------------===//
5743// TOSA Attribute Definitions.
5744//===----------------------------------------------------------------------===//
5745
5746#define GET_ATTRDEF_CLASSES
5747#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
5748
5749//===----------------------------------------------------------------------===//
5750// TOSA Type Definitions.
5751//===----------------------------------------------------------------------===//
5752#define GET_TYPEDEF_CLASSES
5753#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
5754
5755//===----------------------------------------------------------------------===//
5756// TOSA Operator Definitions.
5757//===----------------------------------------------------------------------===//
5758
5759#define GET_OP_CLASSES
5760#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void printInitializationList(OpAsmPrinter &p, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix="")
Prints the initialization list in the form of <prefix>(inner = outer, inner2 = outer2,...
Definition SCF.cpp:496
true
Given two iterators into the same block, return "true" if a is before `b.
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static std::string diag(const llvm::Value &value)
static void printShapeToDiagnostic(InFlightDiagnostic &diag, ArrayRef< int64_t > shape)
Definition TosaOps.cpp:653
static void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value a, Value b)
The tosa.matmul op is also intended to be generated where a fully_connected op must be constructed wh...
Definition TosaOps.cpp:1411
static LogicalResult verifySameElementTypes(T op, Type aType, Type bType, StringRef aName="input", StringRef bName="output")
Definition TosaOps.cpp:1037
LogicalResult inferConvReturnTypeComponents(AdaptorT adaptor, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:4266
static SmallVector< int64_t > convertToMlirShape(ArrayRef< int64_t > shape)
Definition TosaOps.cpp:139
static LogicalResult ReduceInferReturnTypes(ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3814
static void buildAvgPool2dAdaptiveOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseI64ArrayAttr kernel, DenseI64ArrayAttr stride, DenseI64ArrayAttr pad, TypeAttr accType)
This builder mirrors avg_pool2d quant-info handling and materializes kernel/stride/pad as const_shape...
Definition TosaOps.cpp:1484
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val, Value valZp, StringRef name)
Definition TosaOps.cpp:595
static LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type)
Definition TosaOps.cpp:1001
#define REDUCE_SHAPE_INFER(OP)
Definition TosaOps.cpp:3839
static LogicalResult verifyConvOp(T op)
Definition TosaOps.cpp:702
static LogicalResult verifyAvgPoolCommonTypeAndZpChecks(T op)
Definition TosaOps.cpp:1200
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name)
Definition TosaOps.cpp:1010
static LogicalResult poolingInferReturnTypes(ShapeAdaptor inputShape, ArrayRef< int64_t > kernel, ArrayRef< int64_t > stride, ArrayRef< int64_t > pad, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:4029
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value paddings)
This builder is called on TOSA pad operator that needs to create its own OptionalAttr quantization_at...
Definition TosaOps.cpp:1571
static LogicalResult verifyPoolingOpImpl(Operation *op, ArrayRef< int64_t > kernel, ArrayRef< int64_t > strides, ArrayRef< int64_t > padding, Value input, Value output)
Definition TosaOps.cpp:1102
static std::optional< int64_t > idivCheck(const int64_t lhs, const int64_t rhs)
Definition TosaOps.cpp:578
static void buildVariableOp(OpBuilder &builder, OperationState &result, StringRef name, Type variableType, Attribute initialValue)
Definition TosaOps.cpp:1585
LogicalResult verifyConvOutputSize(Operation *op, const int64_t inputSize, const int64_t kernelSize, const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, const llvm::StringRef padAfterName)
Definition TosaOps.cpp:665
static LogicalResult verifyReduceOp(T op)
Definition TosaOps.cpp:3864
#define NARY_SHAPE_INFER(OP)
Definition TosaOps.cpp:3932
#define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND)
Definition TosaOps.cpp:3136
static void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr outpad, DenseI64ArrayAttr stride, TypeAttr accType)
Handles tosa.transpose_conv2d which has outpad and output shape attributes.
Definition TosaOps.cpp:1389
static void extractAdaptivePoolingConstShapeOperands(T op, AdaptivePoolingConstShapeValues &values)
Definition TosaOps.cpp:1258
static LogicalResult verifyConvOpErrorIf(T op)
Definition TosaOps.cpp:856
static FailureOr< int64_t > getZeroPoint(Value val, bool signExtend)
Definition TosaOps.cpp:3055
static constexpr bool IsSupportedAdaptivePoolConstShapeVerifyOp
Definition TosaOps.cpp:1251
LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, const int64_t newDim, const StringRef operandName, const StringRef dimName)
Definition TosaOps.cpp:638
static LogicalResult verifyConvOpModes(T op)
Definition TosaOps.cpp:807
static LogicalResult NAryInferReturnTypes(const ValueShapeRange &operands, SmallVectorImpl< ShapedTypeComponents > &inferredReturnShapes)
Definition TosaOps.cpp:3920
static Type getStorageElementTypeOrSelf(Type type)
Definition TosaOps.cpp:584
#define COMPATIBLE_RETURN_TYPES(OP)
Definition TosaOps.cpp:3830
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector< int64_t > &outShape)
Definition TosaOps.cpp:1629
static void buildNegateOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input)
This builder is called on single-parameter negate operator to construct input and output zero points ...
Definition TosaOps.cpp:1531
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, Value weight, Value bias, DenseI64ArrayAttr pad, DenseI64ArrayAttr stride, DenseI64ArrayAttr dilation, TypeAttr accType)
This builder is called on all convolution operators except TransposeConv, which has specialized outpu...
Definition TosaOps.cpp:1365
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, Type outputType, Value input, DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad, TypeAttr accType)
Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but avg_pool operator has...
Definition TosaOps.cpp:1440
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1, StringRef name1, Type type2, StringRef name2)
Definition TosaOps.cpp:959
static FailureOr< int64_t > getConstantScalarIntValue(Value val)
Definition TosaOps.cpp:3081
static FailureOr< int64_t > resolveBroadcastDim(const int64_t dim1, const int64_t dim2)
Definition TosaOps.cpp:1615
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, const std::string &operand)
Definition TosaOps.cpp:3094
static LogicalResult verifyPoolingOp(T op)
Definition TosaOps.cpp:1194
static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize, const llvm::StringRef dimName)
Definition TosaOps.cpp:1715
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static void updateIfDynamic(int64_t &current, int64_t candidate)
Definition TosaOps.cpp:4068
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4161
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4190
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4135
ConvInferShapeAdaptor(Conv2DBlockScaledOp::Adaptor adaptor)
Definition TosaOps.cpp:4132
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4081
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4096
ConvInferShapeAdaptor(Conv2DOp::Adaptor adaptor)
Definition TosaOps.cpp:4078
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4114
void inferWeightShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &weightSpatial)
Definition TosaOps.cpp:4231
ConvInferShapeAdaptor(Conv3DOp::Adaptor adaptor)
Definition TosaOps.cpp:4211
void inferInputShape(SmallVectorImpl< int64_t > &outputShape, SmallVectorImpl< int64_t > &inputSpatial)
Definition TosaOps.cpp:4214
LogicalResult getSpatialParameters(SmallVector< int64_t > &padValues, SmallVector< int64_t > &strideValues, SmallVector< int64_t > &dilationValues)
Definition TosaOps.cpp:4251
virtual ParseResult parseOptionalRBrace()=0
Parse a } token if present.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
virtual void printAttribute(Attribute attr)
void printArrowTypeList(TypeRange &&types)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MutableArrayRef< BlockArgument > BlockArgListType
Definition Block.h:95
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:232
FloatAttr getFloatAttr(Type type, double value)
Definition Builders.cpp:258
IntegerType getI32Type()
Definition Builders.cpp:67
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:197
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
virtual InFlightDiagnostic emitError(const Twine &msg={}) const =0
Emit an error to the reader.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual OptionalParseResult parseOptionalAssignmentList(SmallVectorImpl< Argument > &lhs, SmallVectorImpl< UnresolvedOperand > &rhs)=0
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
This class indicates that op operates on tosa shape types.
Definition TosaOps.h:129
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
ResultRange result_range
Support result iteration.
Definition Operation.h:436
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition Operation.h:538
OperandRange operand_range
Definition Operation.h:397
operand_type_range getOperandTypes()
Definition Operation.h:423
result_type_range getResultTypes()
Definition Operation.h:454
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
bool empty()
Definition Region.h:60
Adaptor class to abstract the differences between whether value is from a ShapedType or ShapedTypeCom...
bool isDynamicDim(int index) const
Returns whether the index'th dimension is dynamic.
int64_t getDimSize(int index) const
Returns the size of the index'th dimension.
int64_t getRank() const
Returns the rank of the shape.
bool hasStaticShape() const
Returns whether the shape is fully static.
int64_t getNumElements() const
Returns the number of elements in the shape.
void getDims(SmallVectorImpl< int64_t > &res) const
Populates the dimensions from shape referenced.
bool hasRank() const
Returns whether the shape has a rank.
ShapedTypeComponents that represents the components of a ShapedType.
This class allows for representing and managing the symbol table used by operations with the 'SymbolT...
Definition SymbolTable.h:24
Operation * lookup(StringRef name) const
Look up a symbol with the specified name, returning null if no such name exists.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
Range of values and shapes (corresponding effectively to Shapes dialect's ValueShape type concept).
ShapeAdaptor getShape(int index) const
Returns the shape of index'th operand.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands)
LogicalResult verifyTosaShapeOperatorWithSameRanks(Operation *op)
Definition TosaOps.cpp:5609
LogicalResult verifyTosaResolvableShapeOperands(Operation *op)
Definition TosaOps.cpp:5596
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition Traits.cpp:59
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< unsigned > getBlockSize(AffineMap dimToLvl)
Given the dimToLvl map, returns the block sizes in a vector.
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight)
Method to build ConvOpQuantizationAttr, called from ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilde...
RankedTensorType getVariableType(VariableOp variableOp)
Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight)
construct ConvOp output type with correct bitwidth based on input/weight width.
ParseResult parseVariableOpTypeOrInitialValue(OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr, Attribute &initialValueAttr)
Definition TosaOps.cpp:228
PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, Value input)
Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: inputZp: input zeropoint.
constexpr int64_t kInferableDimSize
Represents a dimension in the shape of a tensor that can be inferred based on the other provided dime...
Definition TosaOps.h:153
std::pair< Value, Value > createZPsAsConst(OpBuilder &builder, Value input, Value weight)
void printVariableOpTypeOrInitialValue(OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr, TypeAttr typeAttr, Attribute initialValueAttr)
Definition TosaOps.cpp:253
Value getTosaConstShape(ImplicitLocOpBuilder &builder, llvm::ArrayRef< int64_t > shape)
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b)
Builds MatMulOpQuantizationAttr, called from MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: in...
unsigned getBitWidth(Type type)
Definition TosaOps.cpp:630
std::optional< Value > createZeroPointTensor(OpBuilder &builder, Location loc, Type srcElemType, int64_t zp=0)
Definition TosaOps.cpp:5560
bool isa_tosa_shape_type(mlir::Type t)
Definition TosaOps.cpp:5584
SmallVector< int64_t > convertFromMlirShape(ArrayRef< int64_t > shape)
UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType)
Builds UnaryOpQuantizationAttr UnaryOpQuantInfoBuilder: inputZp: input zeropoint outputZp: output zer...
Type getStorageElementTypeFromQuantized(quant::QuantizedType quantizedType)
Value createPadConstTensor(OpBuilder &builder, Location loc, Value src, int32_t val=0)
Definition TosaOps.cpp:615
bool getConstShapeValues(Operation *op, llvm::SmallVector< int64_t > &result_shape)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
LogicalResult verifyCompatibleDims(ArrayRef< int64_t > dims)
Dimensions are compatible if all non-dynamic dims are equal.
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
static ValueKnowledge meet(const ValueKnowledge &lhs, const ValueKnowledge &rhs)
Definition ShapeUtils.h:136
static ValueKnowledge getKnowledgeFromType(Type type)
Definition ShapeUtils.h:45